mirror of https://github.com/coqui-ai/TTS.git
Add deterministic decoder on VITS
parent 3165c55fae
commit 18dee1190e
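Summary: this commit adds an optional deterministic decoder (z-decoder) to VITS. It introduces a z_decoder_loss_alpha weight in VitsConfig, plumbs a z_decoder_loss term through VitsGeneratorLoss, adds the decoder itself plus a pitch_transform hook to the Vits model, and turns the new flag on in a training test.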
@@ -122,6 +122,7 @@ class VitsConfig(BaseTTSConfig):
     gen_latent_loss_alpha: float = 5.0
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
+    z_decoder_loss_alpha: float = 45.0

     # data loader params
     return_wav: bool = True
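The new z_decoder_loss_alpha joins the other loss weights on VitsConfig, so the deterministic-decoder term can be tuned like any other objective. A minimal sketch of setting it from a training script (the value shown is just the commit's default; anything else is an untested assumption):

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.z_decoder_loss_alpha = 45.0  # weight of the new deterministic-decoder L1 term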
@@ -597,6 +597,7 @@ class VitsGeneratorLoss(nn.Module):
         self.prosody_encoder_kl_loss_alpha = c.prosody_encoder_kl_loss_alpha
         self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
         self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
+        self.z_decoder_loss_alpha = c.z_decoder_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -682,6 +683,7 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_mp=None,
         feats_disc_zp=None,
         pitch_loss=None,
+        z_decoder_loss=None,
     ):
         """
         Shapes:
@@ -762,6 +764,11 @@ class VitsGeneratorLoss(nn.Module):
             loss += pitch_loss
             return_dict["pitch_loss"] = pitch_loss

+        if z_decoder_loss is not None:
+            z_decoder_loss = z_decoder_loss * self.z_decoder_loss_alpha
+            loss += z_decoder_loss
+            return_dict["z_decoder_loss"] = z_decoder_loss
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
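The new term follows the same guard-scale-accumulate pattern as pitch_loss: the raw value is computed in the model's forward pass, weighted here, and reported under its own key. A condensed sketch of the equivalent standalone computation, with made-up tensor shapes for illustration:

    import torch
    import torch.nn.functional as F

    z_hat = torch.randn(2, 192, 100)   # deterministic decoder output [b, c, t]
    z = torch.randn(2, 192, 100)       # posterior-encoder latent (target)
    y_mask = torch.ones(2, 1, 100)     # spectrogram frame mask

    raw = F.l1_loss(z_hat * y_mask, z)  # as computed in Vits.forward below
    weighted = raw * 45.0               # z_decoder_loss_alpha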
@@ -4,7 +4,7 @@ import numpy as np
 import pyworld as pw
 from dataclasses import dataclass, field, replace
 from itertools import chain
-from typing import Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -22,6 +22,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
+from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder
@@ -687,6 +688,14 @@ class VitsArgs(Coqpit):
     use_precomputed_alignments: bool = False
     alignments_cache_path: str = ""
     pitch_embedding_dim: int = 0
+    pitch_mean: float = 0.0
+    pitch_std: float = 0.0
+
+    use_z_decoder: bool = False
+    z_decoder_type: str = "fftransformer"
+    z_decoder_params: dict = field(
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+    )

     detach_dp_input: bool = True
     use_language_embedding: bool = False
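A sketch of switching the decoder on through VitsArgs; the fftransformer type and its parameters are the defaults shipped by this commit, and any other decoder type would have to be one that TTS.tts.layers.feed_forward.decoder.Decoder actually supports:

    from TTS.tts.models.vits import VitsArgs

    model_args = VitsArgs(
        use_z_decoder=True,
        z_decoder_type="fftransformer",
        z_decoder_params={"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1},
    )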
@@ -838,6 +847,21 @@ class Vits(BaseTTS):
                 language_emb_dim=self.embedded_language_dim,
             )

+        if self.args.use_z_decoder:
+            dec_extra_inp_dim = self.cond_embedding_dim
+            if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings:
+                dec_extra_inp_dim += self.args.emotion_embedding_dim
+
+            if self.args.use_prosody_encoder:
+                dec_extra_inp_dim += self.args.prosody_embedding_dim
+
+            self.z_decoder = ZDecoder(
+                self.args.hidden_channels,
+                self.args.hidden_channels + dec_extra_inp_dim,
+                self.args.z_decoder_type,
+                self.args.z_decoder_params,
+            )
+
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
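The second ZDecoder argument is its input width: the text-encoder hidden size plus every embedding that gets concatenated onto the expanded text features. A sketch of the channel arithmetic, assuming the stock 192-channel hidden size and a 512-dim d-vector (both values are assumptions for illustration):

    hidden_channels = 192        # VitsArgs.hidden_channels
    cond_embedding_dim = 512     # speaker d-vector size in this setup
    emotion_embedding_dim = 0    # counted only when emotion embeddings are on
    prosody_embedding_dim = 0    # counted only when the prosody encoder is on

    out_channels = hidden_channels  # the decoder predicts the latent z
    in_channels = (hidden_channels + cond_embedding_dim
                   + emotion_embedding_dim + prosody_embedding_dim)  # 704 here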
@@ -1232,6 +1256,7 @@ class Vits(BaseTTS):
         pitch: torch.FloatTensor = None,
         dr: torch.IntTensor = None,
         g_pp: torch.IntTensor = None,
+        pitch_transform: Callable = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         """Pitch predictor forward pass.

@@ -1245,6 +1270,7 @@ class Vits(BaseTTS):
             pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
             dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
             g_pp (torch.IntTensor, optional): Speaker/prosody embedding to condition the pitch predictor. Defaults to None.
+            pitch_transform (Callable, optional): Pitch transform function. Defaults to None.

         Returns:
             Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
@@ -1267,6 +1293,9 @@ class Vits(BaseTTS):
             g=g_pp.detach() if self.args.detach_pp_input and g_pp is not None else g_pp
         )

+        if pitch_transform is not None:
+            pred_avg_pitch = pitch_transform(pred_avg_pitch, x_mask.sum(dim=(1, 2)), self.args.pitch_mean, self.args.pitch_std)
+
         pitch_loss = None
         pred_avg_pitch_emb = None
         gt_avg_pitch_emb = None
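The call site fixes the transform's contract: it receives the predicted average pitch, the per-item input lengths, and the dataset-level pitch statistics stored on VitsArgs, and must return a tensor shaped like the prediction. A minimal sketch of a compatible callable (the halving factor is a made-up example, and it assumes the predictor works on normalized pitch, as the mean/std arguments suggest):

    def flatten_pitch(pred_avg_pitch, lens, mean, std):
        # Pull the normalized pitch halfway toward the dataset mean,
        # flattening intonation without changing the average register.
        return pred_avg_pitch * 0.5

Passing pitch_transform=flatten_pitch to Vits.inference (threaded through in the hunks below) applies it to every synthesized utterance.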
@@ -1500,6 +1529,30 @@ class Vits(BaseTTS):
         # expand prior
         m_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
         logs_p_expanded = torch.einsum("klmn, kjm -> kjn", [attn, logs_p])
+
+        z_decoder_loss = None
+        if self.args.use_z_decoder:
+            x_expanded = torch.einsum("klmn, kjm -> kjn", [attn, x])
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
+
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+
+            # decoder pass
+            z_decoder = self.z_decoder(x_expanded, y_mask, g=g_dec)
+            z_decoder_loss = torch.nn.functional.l1_loss(z_decoder * y_mask, z)
+
         if self.args.use_noise_scale_predictor:
             nsp_input = torch.transpose(m_p_expanded, 1, -1)
             if self.args.use_prosody_encoder and pros_emb is not None:
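During training, the deterministic decoder learns to regress the posterior-encoder latent z from the alignment-expanded text features with a masked L1 (see the standalone sketch after the loss hunk above), while the main synthesis path still samples through the flow. Note the mask is applied to the prediction only, mirroring the line above; nothing else in the generator objective changes, so the decoder acts as an auxiliary head until inference swaps it in.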
@@ -1575,6 +1628,7 @@ class Vits(BaseTTS):
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
                 "pitch_loss": pitch_loss,
+                "z_decoder_loss": z_decoder_loss,
             }
         )
         return outputs
@@ -1598,6 +1652,7 @@ class Vits(BaseTTS):
             "emotion_ids": None,
             "style_feature": None,
         },
+        pitch_transform=None,
     ):  # pylint: disable=dangerous-default-value
         """
         Note:
@@ -1682,7 +1737,7 @@ class Vits(BaseTTS):

         pred_avg_pitch_emb = None
         if self.args.use_pitch and self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(x, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)

         x, m_p, logs_p, x_mask = self.text_encoder(
             x,
@@ -1720,7 +1775,7 @@ class Vits(BaseTTS):
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2))

         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
-            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp)
+            _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb


@@ -1745,7 +1800,28 @@ class Vits(BaseTTS):
         else:
             z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale

-        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        if self.args.use_z_decoder:
+            x_expanded = torch.matmul(attn.transpose(1, 2), x.transpose(1, 2)).transpose(1, 2)
+            # prepare the conditional emb
+            g_dec = g
+            if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+                if g_dec is None:
+                    g_dec = eg
+                else:
+                    g_dec = torch.cat([g_dec, eg], dim=1)  # [b, h1+h2, 1]
+            if self.args.use_prosody_encoder:
+                if g_dec is None:
+                    g_dec = pros_emb
+                else:
+                    g_dec = torch.cat([g_dec, pros_emb], dim=1)  # [b, h1+h2, 1]
+
+            if g_dec is not None:
+                x_expanded = torch.cat((x_expanded, g_dec.expand(-1, -1, x_expanded.size(2))), dim=1)
+
+            # decoder pass
+            z = self.z_decoder(x_expanded, y_mask, g=g_dec)
+        else:
+            z = self.flow(z_p, y_mask, g=g, reverse=True)

         # upsampling if needed
         z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask)
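At inference, use_z_decoder replaces the stochastic latent path (sampling z_p and inverting the flow) with a direct prediction from the expanded text features, removing the spectrogram-level sampling noise; duration prediction can still be stochastic, depending on the duration-predictor settings. A hedged usage sketch, with model, token_ids, and d_vector as placeholders for a setup trained with use_z_decoder=True:

    import torch

    model.eval()
    with torch.no_grad():
        outputs = model.inference(token_ids, aux_input={"d_vectors": d_vector})
        wav = outputs["model_outputs"]  # waveform decoded from the deterministic latent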
@@ -1892,7 +1968,7 @@ class Vits(BaseTTS):
                 outputs["model_outputs"].detach(),
                 outputs["waveform_seg"],
                 outputs["m_p_unexpanded"].detach(),
-                outputs["z_p_avg"].detach(),
+                outputs["z_p_avg"].detach() if outputs["z_p_avg"] is not None else None,
             )

         # compute loss
@@ -1940,7 +2016,7 @@ class Vits(BaseTTS):
                 self.model_outputs_cache["model_outputs"],
                 self.model_outputs_cache["waveform_seg"],
                 self.model_outputs_cache["m_p_unexpanded"],
-                self.model_outputs_cache["z_p_avg"].detach(),
+                self.model_outputs_cache["z_p_avg"].detach() if self.model_outputs_cache["z_p_avg"] is not None else None,
             )

         # compute losses
@@ -1970,6 +2046,7 @@ class Vits(BaseTTS):
                 feats_disc_mp=feats_disc_mp,
                 feats_disc_zp=feats_disc_zp,
                 pitch_loss=self.model_outputs_cache["pitch_loss"],
+                z_decoder_loss=self.model_outputs_cache["z_decoder_loss"],
             )

         return self.model_outputs_cache, loss_dict
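As with pitch_loss, the raw z_decoder_loss is computed once in the generator forward pass, stashed in model_outputs_cache, and only weighted and summed when VitsGeneratorLoss runs on the generator optimization step; the "is not None" guards in the hunks above keep the training step working when the optional heads are disabled.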
@@ -2337,6 +2414,10 @@ class Vits(BaseTTS):
             # sort input sequences from short to long
             dataset.preprocess_samples()

+            if self.args.use_pitch:
+                self.args.pitch_mean = dataset.f0_dataset.mean
+                self.args.pitch_std = dataset.f0_dataset.std
+
             # get samplers
             sampler = self.get_sampler(config, dataset, num_gpus)

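Storing the dataset-level F0 statistics on VitsArgs is what lets a pitch_transform reason about absolute pitch even though the predictor works on normalized values. The relation, assuming the usual (f0 - mean) / std normalization in the F0 pipeline:

    f0_hz = norm_pitch * pitch_std + pitch_mean    # denormalize
    norm_pitch = (f0_hz - pitch_mean) / pitch_std  # normalize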
@@ -31,6 +31,8 @@ config = VitsConfig(
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60

+config.model_args.use_z_decoder = True
+
 # active multispeaker d-vec mode
 config.model_args.use_d_vector_file = True
 config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"
@@ -79,6 +79,7 @@ out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
 speaker_id = "ljspeech-1"
 continue_speakers_path = os.path.join(continue_path, "speakers.json")

+
 inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path} "
 run_cli(inference_command)
