Add deterministic decoder to VITS

This commit is contained in:
Edresson Casanova 2022-06-20 14:06:16 +00:00
parent 3165c55fae
commit 18dee1190e
5 changed files with 98 additions and 6 deletions

View File

@ -122,6 +122,7 @@ class VitsConfig(BaseTTSConfig):
gen_latent_loss_alpha: float = 5.0
feat_latent_loss_alpha: float = 108.0
pitch_loss_alpha: float = 5.0
z_decoder_loss_alpha: float = 45.0
# data loader params
return_wav: bool = True
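Note: z_decoder_loss_alpha pairs with the use_z_decoder model arg introduced further down. A minimal sketch of enabling both in a training config, mirroring the test recipe at the end of this commit (only the two new fields come from this diff; everything else is left at its default):

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()
config.model_args.use_z_decoder = True  # turn on the deterministic z decoder (see VitsArgs below)
config.z_decoder_loss_alpha = 45.0      # weight of the new L1 term; 45.0 is the default added here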

View File

@ -597,6 +597,7 @@ class VitsGeneratorLoss(nn.Module):
self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
self.z_decoder_loss_alpha = c.z_decoder_loss_alpha
self.stft = TorchSTFT(
c.audio.fft_size,
@ -682,6 +683,7 @@ class VitsGeneratorLoss(nn.Module):
feats_disc_mp=None,
feats_disc_zp=None,
pitch_loss=None,
z_decoder_loss=None,
):
"""
Shapes:
@ -762,6 +764,11 @@ class VitsGeneratorLoss(nn.Module):
loss += pitch_loss
return_dict["pitch_loss"] = pitch_loss
if z_decoder_loss is not None:
z_decoder_loss = z_decoder_loss * self.z_decoder_loss_alpha
loss += z_decoder_loss
return_dict["z_decoder_loss"] = z_decoder_loss
# pass losses to the dict
return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl

View File

@ -4,7 +4,7 @@ import numpy as np
import pyworld as pw
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union
from typing import Callable, Dict, List, Tuple, Union
import torch
import torch.distributed as dist
@ -22,6 +22,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
from TTS.tts.layers.generic.classifier import ReversalClassifier
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
@ -687,6 +688,14 @@ class VitsArgs(Coqpit):
use_precomputed_alignments: bool = False
alignments_cache_path: str = ""
pitch_embedding_dim: int = 0
pitch_mean: float = 0.0
pitch_std: float = 0.0
use_z_decoder: bool = False
z_decoder_type: str = "fftransformer"
z_decoder_params: dict = field(
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
)
detach_dp_input: bool = True
use_language_embedding: bool = False
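The decoder is driven by the three new z_decoder_* fields; pitch_mean and pitch_std are filled in later from the training data (last hunk of this file). A hedged example of setting them explicitly, with values copied from the defaults above:

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_z_decoder=True,
    z_decoder_type="fftransformer",  # forwarded to the feed-forward Decoder factory
    z_decoder_params={
        "hidden_channels_ffn": 1024,
        "num_heads": 1,
        "num_layers": 6,
        "dropout_p": 0.1,
    },
)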
@ -838,6 +847,21 @@ class Vits(BaseTTS):
language_emb_dim=self.embedded_language_dim,
)
if self.args.use_z_decoder:
dec_extra_inp_dim = self.cond_embedding_dim
if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
dec_extra_inp_dim += self.args.emotion_embedding_dim
if self.args.use_prosody_encoder:
dec_extra_inp_dim += self.args.prosody_embedding_dim
self.z_decoder = ZDecoder(
self.args.hidden_channels,
self.args.hidden_channels + dec_extra_inp_dim,
self.args.z_decoder_type,
self.args.z_decoder_params,
)
if self.args.use_pitch:
if self.args.use_pitch_on_enc_input:
self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
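Standalone, self.z_decoder is the generic feed-forward Decoder reused to map frame-level text states (plus any speaker/emotion/prosody channels) back to the dimensionality of the latent z. A shape-level sketch assuming 192 hidden channels and a 256-dim d-vector (sizes are illustrative, not taken from the diff):

import torch
from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder

hidden, extra = 192, 256
z_decoder = ZDecoder(
    hidden,          # out_channels: dimensionality of the posterior latent z
    hidden + extra,  # in_hidden_channels: expanded text states + conditioning channels
    "fftransformer",
    {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1},
)
x_expanded = torch.randn(2, hidden + extra, 100)  # [batch, channels, frames]
y_mask = torch.ones(2, 1, 100)
z_hat = z_decoder(x_expanded, y_mask)             # [batch, hidden, frames]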
@ -1232,6 +1256,7 @@ class Vits(BaseTTS):
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
g_pp: torch.IntTensor = None,
pitch_transform: Callable = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Pitch predictor forward pass.
@ -1245,6 +1270,7 @@ class Vits(BaseTTS):
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.
pitch_transform (Callable, optional): Pitch transform function. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@ -1267,6 +1293,9 @@ class Vits(BaseTTS):
g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
)
if pitch_transform is not None:
pred_avg_pitch = pitch_transform(pred_avg_pitch, x_mask.sum(dim=(1, 2)), self.args.pitch_mean, self.args.pitch_std)
pitch_loss = None
pred_avg_pitch_emb = None
gt_avg_pitch_emb = None
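pitch_transform is applied only here, on the averaged pitch prediction, and receives the per-sample token counts plus the dataset statistics cached in pitch_mean/pitch_std. A hypothetical transform that raises every utterance by 20 Hz, assuming the predictions live in the normalized F0 space produced by F0Dataset:

def raise_pitch_20hz(pred_avg_pitch, input_lengths, pitch_mean, pitch_std):
    # pred_avg_pitch: normalized average pitch per input token
    # input_lengths: x_mask.sum(dim=(1, 2)); unused here, kept for the expected signature
    f0 = pred_avg_pitch * pitch_std + pitch_mean  # back to Hz
    f0 = f0 + 20.0                                # shift the whole utterance up
    return (f0 - pitch_mean) / pitch_std          # re-normalize before the pitch embedding

# e.g. model.inference(tokens, pitch_transform=raise_pitch_20hz)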
@ -1500,6 +1529,30 @@ class Vits(BaseTTS):
# expand prior
m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
z_decoder_loss = None
if self.args.use_z_decoder:
x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
# prepare the conditioning embedding
g_dec = g
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dec is None:
g_dec = eg
else:
g_dec = torch.cat([g_dec, eg], dim=1) # [b, h1+h2, 1]
if self.args.use_prosody_encoder:
if g_dec is None:
g_dec = pros_emb
else:
g_dec = torch.cat([g_dec, pros_emb], dim=1) # [b, h1+h2, 1]
if g_dec is not None:
x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
# decoder pass
z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
if self.args.use_noise_scale_predictor:
nsp_input = torch.transpose(m_p_expanded, 1, -1)
if self.args.use_prosody_encoder and pros_emb is not None:
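The new block reuses the alignment to expand the text-encoder states to frame resolution (the same einsum used for m_p and logs_p above) and then stacks the conditioning vectors along the channel axis before the decoder pass. A shape-only sketch with illustrative sizes:

import torch

B, H, T_text, T_frames = 2, 192, 13, 100
attn = torch.rand(B, 1, T_text, T_frames)  # hard alignment
x = torch.randn(B, H, T_text)              # text-encoder hidden states
g = torch.randn(B, 256, 1)                 # e.g. a speaker d-vector

x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])  # [B, H, T_frames]
x_expanded = torch.cat((x_expanded, g.expand(-1, -1, x_expanded.size(2))), dim=1)  # [B, H + 256, T_frames]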
@ -1575,6 +1628,7 @@ class Vits(BaseTTS):
"loss_text_enc_spk_rev_classifier": l_text_speaker,
"loss_text_enc_emo_classifier": l_text_emotion,
"pitch_loss": pitch_loss,
"z_decoder_loss": z_decoder_loss,
}
)
return outputs
@ -1598,6 +1652,7 @@ class Vits(BaseTTS):
"emotion_ids": None,
"style_feature": None,
},
pitch_transform=None,
): # pylint: disable=dangerous-default-value
"""
Note:
@ -1682,7 +1737,7 @@ class Vits(BaseTTS):
pred_avg_pitch_emb = None
if self.args.use_pitch and self.args.use_pitch_on_enc_input:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
x, m_p, logs_p, x_mask = self.text_encoder(
x,
@ -1720,7 +1775,7 @@ class Vits(BaseTTS):
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))
if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
_, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
m_p = m_p + pred_avg_pitch_emb
@ -1745,7 +1800,28 @@ class Vits(BaseTTS):
else:
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
z = self.flow(z_p, y_mask, g=g, reverse=True)
if self.args.use_z_decoder:
x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
# prepare the conditioning embedding
g_dec = g
if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
if g_dec is None:
g_dec = eg
else:
g_dec = torch.cat([g_dec, eg], dim=1) # [b, h1+h2, 1]
if self.args.use_prosody_encoder:
if g_dec is None:
g_dec = pros_emb
else:
g_dec = torch.cat([g_dec, pros_emb], dim=1) # [b, h1+h2, 1]
if g_dec is not None:
x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
# decoder pass
z = self.z_decoder(x_expanded, y_mask, g=g_dec)
else:
z = self.flow(z_p, y_mask, g=g, reverse=True)
# upsampling if needed
z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
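At inference the deterministic branch replaces the inverted flow: z is predicted from the expanded text states and the conditioning embeddings, so the sampled z_p no longer reaches the waveform decoder. Calling the model is unchanged; a hedged sketch, where model is a Vits instance trained with use_z_decoder=True and tokens are already-encoded input ids:

import torch

tokens = torch.LongTensor([[1, 5, 9, 3, 2]])  # placeholder ids
with torch.no_grad():
    outputs = model.inference(tokens)         # the z decoder branch is selected from model.args
wav = outputs["model_outputs"]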
@ -1892,7 +1968,7 @@ class Vits(BaseTTS):
outputs["model_outputs"].detach(),
outputs["waveform_seg"],
outputs["m_p_unexpanded"].detach(),
outputs["z_p_avg"].detach(),
outputs["z_p_avg"].detach() if outputs["z_p_avg"] is not None else None,
)
# compute loss
@ -1940,7 +2016,7 @@ class Vits(BaseTTS):
self.model_outputs_cache["model_outputs"],
self.model_outputs_cache["waveform_seg"],
self.model_outputs_cache["m_p_unexpanded"],
self.model_outputs_cache["z_p_avg"].detach(),
self.model_outputs_cache["z_p_avg"].detach() if self.model_outputs_cache["z_p_avg"] is not None else None,
)
# compute losses
@ -1970,6 +2046,7 @@ class Vits(BaseTTS):
feats_disc_mp=feats_disc_mp,
feats_disc_zp=feats_disc_zp,
pitch_loss=self.model_outputs_cache["pitch_loss"],
z_decoder_loss=self.model_outputs_cache["z_decoder_loss"],
)
return self.model_outputs_cache, loss_dict
@ -2337,6 +2414,10 @@ class Vits(BaseTTS):
# sort input sequences from short to long
dataset.preprocess_samples()
if self.args.use_pitch:
self.args.pitch_mean = dataset.f0_dataset.mean
self.args.pitch_std = dataset.f0_dataset.std
# get samplers
sampler = self.get_sampler(config, dataset, num_gpus)
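Caching dataset.f0_dataset.mean and .std on the model args is what makes a pitch_transform usable at inference time, when the F0Dataset is no longer available: the transform can undo and redo the dataset normalization from the stored statistics. A toy check with made-up numbers (assuming the usual mean/std normalization of F0Dataset):

pitch_mean, pitch_std = 110.0, 35.0              # hypothetical statistics, not from the diff
raw_f0 = 180.0                                   # Hz
norm_f0 = (raw_f0 - pitch_mean) / pitch_std      # how F0Dataset normalizes
assert abs(norm_f0 * pitch_std + pitch_mean - raw_f0) < 1e-6  # a transform can invert it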

View File

@ -31,6 +31,8 @@ config = VitsConfig(
config.audio.do_trim_silence = True
config.audio.trim_db = 60
config.model_args.use_z_decoder = True
# activate multi-speaker d-vec mode
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"

View File

@ -79,6 +79,7 @@ out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
run_cli(inference_command)