Add deterministic decoder to VITS

Edresson Casanova 2022-06-20 14:06:16 +00:00
parent 3165c55fae
commit 18dee1190e
5 changed files with 98 additions and 6 deletions

View File

@@ -122,6 +122,7 @@ class VitsConfig(BaseTTSConfig):
     gen_latent_loss_alpha: float = 5.0
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
+    z_decoder_loss_alpha: float = 45.0
     # data loader params
     return_wav: bool = True
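
The new `z_decoder_loss_alpha` is consumed by `VitsGeneratorLoss` below. A minimal usage sketch, assuming the Coqui TTS package layout shown in this diff; the value simply mirrors the new default:

```python
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
# weight of the deterministic (z) decoder's L1 reconstruction term in the generator loss
config.z_decoder_loss_alpha = 45.0
```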

View File

@@ -597,6 +597,7 @@ class VitsGeneratorLoss(nn.Module):
         self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
         self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
         self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
+        self.z_decoder_loss_alpha = c.z_decoder_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -682,6 +683,7 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_mp=None,
         feats_disc_zp=None,
         pitch_loss=None,
+        z_decoder_loss=None,
     ):
         """
         Shapes:
@@ -762,6 +764,11 @@ class VitsGeneratorLoss(nn.Module):
             loss += pitch_loss
             return_dict["pitch_loss"] = pitch_loss
+        if z_decoder_loss is not None:
+            z_decoder_loss = z_decoder_loss * self.z_decoder_loss_alpha
+            loss += z_decoder_loss
+            return_dict["z_decoder_loss"] = z_decoder_loss
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl

View File

@@ -4,7 +4,7 @@ import numpy as np
 import pyworld as pw
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -22,6 +22,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
@@ -687,6 +688,14 @@ class VitsArgs(Coqpit):
     use_precomputed_alignments: bool = False
     alignments_cache_path: str = ""
     pitch_embedding_dim: int = 0
+    pitch_mean: float = 0.0
+    pitch_std: float = 0.0
+    use_z_decoder: bool = False
+    z_decoder_type: str = "fftransformer"
+    z_decoder_params: dict = field(
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+    )
     detach_dp_input: bool = True
     use_language_embedding: bool = False
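
A sketch of how the new fields could be set when building the model args. The field names come from the diff; the shallower `num_layers` is only an illustrative override, and keyword construction is assumed to work as for any Coqpit dataclass:

```python
from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_z_decoder=True,
    z_decoder_type="fftransformer",
    # same keys as the default z_decoder_params above; a smaller stack for a quick experiment
    z_decoder_params={"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 4, "dropout_p": 0.1},
)
```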
@@ -838,6 +847,21 @@ class Vits(BaseTTS):
             language_emb_dim=self.embedded_language_dim,
         )

+        if self.args.use_z_decoder:
+            dec_extra_inp_dim = self.cond_embedding_dim
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                dec_extra_inp_dim += self.args.emotion_embedding_dim
+            if self.args.use_prosody_encoder:
+                dec_extra_inp_dim += self.args.prosody_embedding_dim
+            self.z_decoder = ZDecoder(
+                self.args.hidden_channels,
+                self.args.hidden_channels + dec_extra_inp_dim,
+                self.args.z_decoder_type,
+                self.args.z_decoder_params,
+            )
+
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
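
A shape-check sketch of the decoder built above, reusing the constructor and forward call exactly as the diff uses them; the hidden size (192) and the extra conditioning width (512, e.g. a speaker d-vector) are assumptions for illustration only:

```python
import torch
from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder

hidden_channels = 192        # assumed VITS latent size
dec_extra_inp_dim = 512      # assumed conditioning width (e.g. a d-vector)

z_decoder = ZDecoder(
    hidden_channels,                       # out_channels: must match the posterior/flow latent z
    hidden_channels + dec_extra_inp_dim,   # in_channels: expanded text features + conditioning
    "fftransformer",
    {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1},
)

x_expanded = torch.randn(2, hidden_channels + dec_extra_inp_dim, 100)  # [B, C_in, T_dec]
y_mask = torch.ones(2, 1, 100)
z_hat = z_decoder(x_expanded, y_mask)
print(z_hat.shape)  # expected: torch.Size([2, 192, 100])
```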
@@ -1232,6 +1256,7 @@ class Vits(BaseTTS):
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         g_pp: torch.IntTensor = None,
+        pitch_transform: Callable = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         """Pitch predictor forward pass.
@@ -1245,6 +1270,7 @@ class Vits(BaseTTS):
             pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
             dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
             g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.
+            pitch_transform (Callable, optional): Pitch transform function. Defaults to None.

         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@@ -1267,6 +1293,9 @@ class Vits(BaseTTS):
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )

+        if pitch_transform is not None:
+            pred_avg_pitch = pitch_transform(pred_avg_pitch, x_mask.sum(dim=(1, 2)), self.args.pitch_mean, self.args.pitch_std)
+
         pitch_loss = None
         pred_avg_pitch_emb = None
         gt_avg_pitch_emb = None
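
The transform receives the predicted average pitch, the input lengths, and the dataset f0 statistics, in that order (per the call above). A hedged example of such a callable, assuming the predictor works in mean-variance normalized f0 space and that `pred_avg_pitch` is shaped [B, 1, T]:

```python
def raise_pitch_transform(pred_avg_pitch, x_lengths, pitch_mean, pitch_std, factor=1.2):
    """Illustrative pitch_transform: de-normalize to Hz, scale, re-normalize.
    x_lengths is accepted to match the expected signature but unused here."""
    pitch_hz = pred_avg_pitch * pitch_std + pitch_mean
    pitch_hz = pitch_hz * factor  # raise pitch by 20%
    return (pitch_hz - pitch_mean) / pitch_std
```

Such a callable would then be passed as `pitch_transform` to `Vits.inference`, which forwards it to the pitch predictor as shown further down.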
@@ -1500,6 +1529,30 @@ class Vits(BaseTTS):
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+
+        z_decoder_loss = None
+        if self.args.use_z_decoder:
+            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            # decoder pass
+            z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
+
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p_expanded, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
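
For reference, a toy shape walk-through of the expansion and masked L1 used above; the sizes and the random hard alignment are made up for illustration:

```python
import torch
import torch.nn.functional as F

B, C, T_en, T_de = 2, 192, 13, 37
attn = torch.zeros(B, 1, T_en, T_de)  # hard alignment: [batch, 1, T_enc, T_dec]
attn[:, 0, torch.randint(0, T_en, (T_de,)), torch.arange(T_de)] = 1.0

x = torch.randn(B, C, T_en)  # text-encoder hidden states
x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
print(x_expanded.shape)  # torch.Size([2, 192, 37]) -> one text vector per decoder frame

# masked L1 between a (stand-in) decoder output and the posterior latent z
z_hat = torch.randn(B, C, T_de)
z = torch.randn(B, C, T_de)
y_mask = torch.ones(B, 1, T_de)
z_decoder_loss = F.l1_loss(z_hat * y_mask, z)
```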
@@ -1575,6 +1628,7 @@ class Vits(BaseTTS):
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
                 "pitch_loss": pitch_loss,
+                "z_decoder_loss": z_decoder_loss,
             }
         )
         return outputs
@@ -1598,6 +1652,7 @@ class Vits(BaseTTS):
             "emotion_ids": None,
             "style_feature": None,
         },
+        pitch_transform=None,
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1682,7 +1737,7 @@ class Vits(BaseTTS):
         pred_avg_pitch_emb = None
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
@@ -1720,7 +1775,7 @@ class Vits(BaseTTS):
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb
@@ -1745,7 +1800,28 @@ class Vits(BaseTTS):
         else:
             z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

-        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        if self.args.use_z_decoder:
+            x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+            # decoder pass
+            z = self.z_decoder(x_expanded, y_mask, g=g_dec)
+        else:
+            z = self.flow(z_p, y_mask, g=g, reverse=True)

         # upsampling if needed
         z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
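
The inference path expands the text states with a plain matmul instead of the training-time einsum. A small sketch showing the two forms agree, assuming the inference alignment is shaped [B, T_enc, T_dec] as the transpose pattern above implies:

```python
import torch

B, C, T_en, T_de = 2, 192, 13, 37
attn = torch.zeros(B, T_en, T_de)  # assumed inference alignment from generate_path
attn[:, torch.randint(0, T_en, (T_de,)), torch.arange(T_de)] = 1.0
x = torch.randn(B, C, T_en)

# matmul form used at inference
x_matmul = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
# einsum form used in training (with the extra singleton dim)
x_einsum = torch.einsum("klmn, kjm -> kjn", [attn.unsqueeze(1), x])

print(torch.allclose(x_matmul, x_einsum))  # True
```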
@@ -1892,7 +1968,7 @@ class Vits(BaseTTS):
                 outputs["model_outputs"].detach(),
                 outputs["waveform_seg"],
                 outputs["m_p_unexpanded"].detach(),
-                outputs["z_p_avg"].detach(),
+                outputs["z_p_avg"].detach() if outputs["z_p_avg"] is not None else None,
             )

             # compute loss
@@ -1940,7 +2016,7 @@ class Vits(BaseTTS):
                 self.model_outputs_cache["model_outputs"],
                 self.model_outputs_cache["waveform_seg"],
                 self.model_outputs_cache["m_p_unexpanded"],
-                self.model_outputs_cache["z_p_avg"].detach(),
+                self.model_outputs_cache["z_p_avg"].detach() if self.model_outputs_cache["z_p_avg"] is not None else None,
             )

             # compute losses
@@ -1970,6 +2046,7 @@ class Vits(BaseTTS):
                 feats_disc_mp=feats_disc_mp,
                 feats_disc_zp=feats_disc_zp,
                 pitch_loss=self.model_outputs_cache["pitch_loss"],
+                z_decoder_loss=self.model_outputs_cache["z_decoder_loss"],
             )

         return self.model_outputs_cache, loss_dict
@@ -2337,6 +2414,10 @@ class Vits(BaseTTS):
             # sort input sequences from short to long
             dataset.preprocess_samples()

+            if self.args.use_pitch:
+                self.args.pitch_mean = dataset.f0_dataset.mean
+                self.args.pitch_std = dataset.f0_dataset.std
+
             # get samplers
             sampler = self.get_sampler(config, dataset, num_gpus)
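
The dataset-level f0 mean/std stored here are what the pitch predictor's `pitch_transform` receives at inference. A hedged sketch of how such statistics are typically derived (whether F0Dataset drops unvoiced frames exactly like this is an assumption; the f0 tracks are made up):

```python
import numpy as np

# hypothetical per-utterance f0 tracks in Hz, with 0.0 marking unvoiced frames
f0_tracks = [np.array([110.0, 0.0, 115.0, 121.0]), np.array([0.0, 98.0, 102.0])]

voiced = np.concatenate([f0[f0 > 0] for f0 in f0_tracks])
pitch_mean, pitch_std = float(voiced.mean()), float(voiced.std())
print(pitch_mean, pitch_std)  # these play the role of args.pitch_mean / args.pitch_std above
```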

View File

@@ -31,6 +31,8 @@ config = VitsConfig(
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60
+config.model_args.use_z_decoder = True
+
 # active multispeaker d-vec mode
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"

View File

@@ -79,6 +79,7 @@ out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")
 inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
 run_cli(inference_command)