mirror of https://github.com/coqui-ai/TTS.git
Add conditional module on VITS
parent 18dee1190e
commit ef27039190
@@ -123,6 +123,7 @@ class VitsConfig(BaseTTSConfig):
     feat_latent_loss_alpha: float = 108.0
     pitch_loss_alpha: float = 5.0
     z_decoder_loss_alpha: float = 45.0
+    conditional_module_loss_alpha: float = 45.0

     # data loader params
     return_wav: bool = True

@@ -598,6 +598,7 @@ class VitsGeneratorLoss(nn.Module):
         self.feat_latent_loss_alpha = c.feat_latent_loss_alpha
         self.gen_latent_loss_alpha = c.gen_latent_loss_alpha
         self.z_decoder_loss_alpha = c.z_decoder_loss_alpha
+        self.conditional_module_loss_alpha = c.conditional_module_loss_alpha

         self.stft = TorchSTFT(
             c.audio.fft_size,

@@ -684,6 +685,7 @@ class VitsGeneratorLoss(nn.Module):
         feats_disc_zp=None,
         pitch_loss=None,
         z_decoder_loss=None,
+        conditional_module_loss=None,
     ):
         """
         Shapes:

@@ -769,6 +771,11 @@ class VitsGeneratorLoss(nn.Module):
             loss += z_decoder_loss
             return_dict["z_decoder_loss"] = z_decoder_loss

+        if conditional_module_loss is not None:
+            conditional_module_loss = conditional_module_loss * self.conditional_module_loss_alpha
+            loss += conditional_module_loss
+            return_dict["conditional_module_loss"] = conditional_module_loss
+
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl

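The new term follows the same pattern as the other auxiliary losses: skip it when the module is disabled, scale it by its alpha from the config, add it to the running total, and record it in the returned dict. A minimal standalone sketch of that pattern (the helper name and values are illustrative, not the repo's API):

    import torch

    def apply_aux_loss(total_loss, return_dict, name, value, alpha):
        # Terms the forward pass did not produce (module disabled) are skipped.
        if value is None:
            return total_loss
        weighted = value * alpha  # e.g. conditional_module_loss_alpha = 45.0
        return_dict[name] = weighted
        return total_loss + weighted

    logs = {}
    total = apply_aux_loss(torch.tensor(0.0), logs, "conditional_module_loss", torch.tensor(0.02), 45.0)
    print(total, logs["conditional_module_loss"])  # tensor(0.9000) tensor(0.9000)
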
@@ -22,7 +22,7 @@ from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, F0Dataset
 from TTS.tts.layers.generic.classifier import ReversalClassifier
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
-from TTS.tts.layers.feed_forward.decoder import Decoder as ZDecoder
+from TTS.tts.layers.feed_forward.decoder import Decoder as forwardDecoder
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE, ResNetProsodyEncoder

@@ -677,6 +677,12 @@ class VitsArgs(Coqpit):
     use_noise_scale_predictor: bool = False
     use_latent_discriminator: bool = False

+    use_encoder_conditional_module: bool = False
+    conditional_module_type: str = "fftransformer"
+    conditional_module_params: dict = field(
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 3, "dropout_p": 0.1}
+    )
+
     # Pitch predictor
     use_pitch_on_enc_input: bool = False
     use_pitch: bool = False

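Mirroring the test recipe at the end of this commit, the module is switched on through `model_args`. A minimal configuration sketch, assuming the usual `VitsConfig` import path in this repo and reusing the default `conditional_module_params` shown above:

    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig()
    config.model_args.use_prosody_encoder = True            # optional extra conditioning input
    config.model_args.prosody_embedding_dim = 64
    config.model_args.use_encoder_conditional_module = True
    config.model_args.conditional_module_type = "fftransformer"
    config.model_args.conditional_module_params = {
        "hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 3, "dropout_p": 0.1,
    }
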
@@ -782,7 +788,7 @@ class Vits(BaseTTS):
             self.args.dropout_p_text_encoder,
             language_emb_dim=self.embedded_language_dim,
             emotion_emb_dim=self.args.emotion_embedding_dim if not self.args.use_noise_scale_predictor else 0,
-            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor else 0,
+            prosody_emb_dim=self.args.prosody_embedding_dim if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else 0,
             pitch_dim=self.args.pitch_embedding_dim if self.args.use_pitch and self.args.use_pitch_on_enc_input else 0,
         )

@@ -821,7 +827,7 @@ class Vits(BaseTTS):
         ) and not self.args.use_noise_scale_predictor:
             dp_extra_inp_dim += self.args.emotion_embedding_dim

-        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor:
+        if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module:
             dp_extra_inp_dim += self.args.prosody_embedding_dim

         if self.args.use_pitch and self.args.use_pitch_on_enc_input:

@@ -855,13 +861,26 @@ class Vits(BaseTTS):
             if self.args.use_prosody_encoder:
                 dec_extra_inp_dim += self.args.prosody_embedding_dim

-            self.z_decoder = ZDecoder(
+            self.z_decoder = forwardDecoder(
                 self.args.hidden_channels,
                 self.args.hidden_channels + dec_extra_inp_dim,
                 self.args.z_decoder_type,
                 self.args.z_decoder_params,
             )

+        if self.args.use_encoder_conditional_module:
+            extra_inp_dim = 0
+            if self.args.use_prosody_encoder:
+                extra_inp_dim += self.args.prosody_embedding_dim
+
+            self.encoder_conditional_module = forwardDecoder(
+                self.args.hidden_channels,
+                self.args.hidden_channels + extra_inp_dim,
+                self.args.conditional_module_type,
+                self.args.conditional_module_params,
+            )
+
         if self.args.use_pitch:
             if self.args.use_pitch_on_enc_input:
                 self.pitch_predictor_vocab_emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)

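The conditional module reuses the same feed-forward `Decoder` wrapper as the z-decoder (imported as `forwardDecoder`), so it maps text-encoder features, optionally concatenated with a prosody embedding, back to `hidden_channels` under the text mask. A rough shape sketch, assuming the constructor and call signature exactly as used in this diff, i.e. `Decoder(out_channels, in_hidden_channels, decoder_type, decoder_params)` and `module(x, x_mask)`; the dimensions are dummy values:

    import torch
    from TTS.tts.layers.feed_forward.decoder import Decoder as forwardDecoder

    hidden_channels, prosody_dim, T = 192, 64, 37
    module = forwardDecoder(
        hidden_channels,                # output channels, matching m_p / z_p
        hidden_channels + prosody_dim,  # input channels: text features + prosody embedding
        "fftransformer",
        {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 3, "dropout_p": 0.1},
    )
    x = torch.randn(2, hidden_channels + prosody_dim, T)  # conditioned text-encoder output
    x_mask = torch.ones(2, 1, T)
    new_m_p = module(x, x_mask)
    print(new_m_p.shape)  # expected: torch.Size([2, 192, 37])
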
@@ -1495,7 +1514,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
             pitch_emb=gt_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

@@ -1517,6 +1536,23 @@ class Vits(BaseTTS):

         outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g_dp, lang_emb=lang_emb)

+        conditional_module_loss = None
+        if self.args.use_encoder_conditional_module:
+            g_cond = None
+            cond_module_input = x
+            if self.args.use_prosody_encoder:
+                if g_cond is None:
+                    g_cond = pros_emb
+                else:
+                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
+
+            if g_cond is not None:
+                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
+
+            new_m_p = self.encoder_conditional_module(cond_module_input, x_mask)
+            z_p_avg = average_over_durations(z_p, attn.sum(3).squeeze()).detach()
+            conditional_module_loss = torch.nn.functional.l1_loss(new_m_p * x_mask, z_p_avg)
+
         if self.args.use_pitch and not self.args.use_pitch_on_enc_input:
             pitch_loss, gt_avg_pitch_emb, _ = self.forward_pitch_predictor(m_p, x_lengths, pitch, attn.sum(3), g_dp)
             m_p = m_p + gt_avg_pitch_emb

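The regression target here is the frame-level flow latent `z_p` averaged over each input token's duration (taken from the aligner output `attn`), detached so the latent itself is not pulled toward the module's prediction. A hand-rolled sketch of that per-token averaging, for illustration only (the repo uses its own `average_over_durations` helper):

    import torch

    def average_over_durations_ref(values, durs):
        # values: [B, C, T_frames]; durs: [B, T_tokens] integer frame counts per token
        B, C, _ = values.shape
        out = torch.zeros(B, C, durs.shape[1])
        for b in range(B):
            start = 0
            for t, d in enumerate(durs[b].tolist()):
                d = int(d)
                if d > 0:
                    out[b, :, t] = values[b, :, start:start + d].mean(dim=-1)
                start += d
        return out

    z_p = torch.randn(1, 192, 10)     # frame-level latent
    durs = torch.tensor([[3, 2, 5]])  # attn.sum(3).squeeze() per token
    print(average_over_durations_ref(z_p, durs).detach().shape)  # torch.Size([1, 192, 3])
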
@@ -1628,7 +1664,8 @@ class Vits(BaseTTS):
                 "loss_text_enc_spk_rev_classifier": l_text_speaker,
                 "loss_text_enc_emo_classifier": l_text_emotion,
                 "pitch_loss": pitch_loss,
                 "z_decoder_loss": z_decoder_loss,
+                "conditional_module_loss": conditional_module_loss,
             }
         )
         return outputs

@@ -1744,7 +1781,7 @@ class Vits(BaseTTS):
             x_lengths,
             lang_emb=lang_emb,
             emo_emb=eg if not self.args.use_noise_scale_predictor else None,
-            pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None,
+            pros_emb=pros_emb if not self.args.use_noise_scale_predictor and not self.args.use_encoder_conditional_module else None,
             pitch_emb=pred_avg_pitch_emb if self.args.use_pitch and self.args.use_pitch_on_enc_input else None,
         )

@@ -1778,6 +1815,18 @@ class Vits(BaseTTS):
             _, _, pred_avg_pitch_emb = self.forward_pitch_predictor(m_p, x_lengths, g_pp=g_dp, pitch_transform=pitch_transform)
             m_p = m_p + pred_avg_pitch_emb

+        if self.args.use_encoder_conditional_module:
+            g_cond = None
+            cond_module_input = x
+            if self.args.use_prosody_encoder:
+                if g_cond is None:
+                    g_cond = pros_emb
+                else:
+                    g_cond = torch.cat([g_cond, pros_emb], dim=1)  # [b, h1+h2, 1]
+
+            if g_cond is not None:
+                cond_module_input = torch.cat((cond_module_input, g_cond.expand(-1, -1, cond_module_input.size(2))), dim=1)
+            m_p = self.encoder_conditional_module(cond_module_input, x_mask)

         m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)

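At inference the module's prediction simply replaces the text-encoder mean `m_p` before it is expanded to frame level by the predicted alignment. The conditioning input is the text-encoder output with the utterance-level prosody vector broadcast along the token axis. A shape-only sketch of that assembly and the subsequent expansion (tensor names mirror the diff; values are dummies):

    import torch

    B, H, P, T_tokens, T_frames = 2, 192, 64, 7, 25
    x = torch.randn(B, H, T_tokens)             # text-encoder output
    pros_emb = torch.randn(B, P, 1)             # one prosody vector per utterance
    attn = torch.zeros(B, T_tokens, T_frames)   # hard alignment (tokens -> frames)

    # broadcast the prosody vector over tokens and concatenate on the channel axis
    cond_module_input = torch.cat((x, pros_emb.expand(-1, -1, x.size(2))), dim=1)
    print(cond_module_input.shape)  # torch.Size([2, 256, 7])

    # the module's output takes the place of m_p ([B, H, T_tokens]) and is then
    # expanded to frame level exactly like the original m_p:
    m_p = torch.randn(B, H, T_tokens)
    m_p_frames = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2)
    print(m_p_frames.shape)  # torch.Size([2, 192, 25])
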
@@ -1843,9 +1892,13 @@ class Vits(BaseTTS):

     def compute_style_feature(self, style_wav_path):
         style_wav, sr = torchaudio.load(style_wav_path)
-        if sr != self.config.audio.sample_rate:
+        if sr != self.config.audio.sample_rate and self.args.encoder_sample_rate is None:
             raise RuntimeError(
-                " [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
+                f" [!] Style reference need to have sampling rate equal to {self.config.audio.sample_rate} !!"
             )
+        elif self.args.encoder_sample_rate is not None and sr != self.args.encoder_sample_rate:
+            raise RuntimeError(
+                f" [!] Style reference need to have sampling rate equal to {self.args.encoder_sample_rate} !!"
+            )
         y = wav_to_spec(
             style_wav.unsqueeze(1),

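The reworked check means the style reference must match `encoder_sample_rate` whenever that is set, and only falls back to `config.audio.sample_rate` otherwise (the first error message also gains its missing `f` prefix). A tiny restatement of the branching, with a hypothetical helper name:

    def expected_style_sr(config_sr, encoder_sr):
        # Sampling rate a style reference must have under the new check.
        return config_sr if encoder_sr is None else encoder_sr

    assert expected_style_sr(22050, None) == 22050    # falls back to config.audio.sample_rate
    assert expected_style_sr(22050, 16000) == 16000   # encoder_sample_rate takes precedence
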
@@ -2047,6 +2100,7 @@ class Vits(BaseTTS):
                 feats_disc_zp=feats_disc_zp,
                 pitch_loss=self.model_outputs_cache["pitch_loss"],
                 z_decoder_loss=self.model_outputs_cache["z_decoder_loss"],
+                conditional_module_loss=self.model_outputs_cache["conditional_module_loss"]
             )

         return self.model_outputs_cache, loss_dict

@@ -33,6 +33,7 @@ def run_model_torch(
     emotion_embedding: torch.Tensor = None,
     style_speaker_id: torch.Tensor = None,
     style_speaker_d_vector: torch.Tensor = None,
+    pitch_transform: torch.Tensor = None,
 ) -> Dict:
     """Run a torch model for inference. It does not support batch inference.

@@ -53,6 +54,7 @@ def run_model_torch(
         _func = model.inference
     outputs = _func(
         inputs,
+        pitch_transform=pitch_transform,
         aux_input={
             "x_lengths": input_lengths,
             "speaker_ids": speaker_id,

@@ -134,6 +136,7 @@ def synthesis(
     emotion_embedding=None,
     style_speaker_id=None,
     style_speaker_d_vector=None,
+    pitch_transform=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.

@@ -243,6 +246,7 @@ def synthesis(
         emotion_embedding=emotion_embedding,
         style_speaker_id=style_speaker_id,
         style_speaker_d_vector=style_speaker_d_vector,
+        pitch_transform=pitch_transform,
     )
     model_outputs = outputs["model_outputs"]
     model_outputs = model_outputs[0].data.cpu().numpy()

@@ -215,6 +215,7 @@ class Synthesizer(object):
         reference_speaker_name=None,
         emotion_name=None,
         style_speaker_name=None,
+        pitch_transform=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.

@@ -381,6 +382,7 @@ class Synthesizer(object):
                     language_id=language_id,
                     emotion_embedding=emotion_embedding,
                     emotion_id=emotion_id,
+                    pitch_transform=pitch_transform,
                 )
                 waveform = outputs["wav"]
                 mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()

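The hunks above only thread a new `pitch_transform` argument through the public call chain, `Synthesizer.tts` → `synthesis` → `run_model_torch` → `model.inference`, where the VITS pitch predictor consumes it. A hedged usage sketch, assuming a trained checkpoint is available (paths are hypothetical, and the transform's exact semantics are defined by `forward_pitch_predictor`, which this commit does not change):

    from TTS.utils.synthesizer import Synthesizer

    synthesizer = Synthesizer(tts_checkpoint="model.pth", tts_config_path="config.json", use_cuda=False)

    # pitch_transform is forwarded untouched down to model.inference(); passing None
    # keeps the pitch predictor's default behaviour.
    wav = synthesizer.tts("Hello world.", pitch_transform=None)
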
@@ -31,7 +31,7 @@ config = VitsConfig(
 config.audio.do_trim_silence = True
 config.audio.trim_db = 60

-config.model_args.use_z_decoder = True
+config.model_args.use_z_decoder = False

 # active multispeaker d-vec mode
 config.model_args.use_d_vector_file = True

@@ -45,11 +45,14 @@ config.model_args.d_vector_dim = 128
 # prosody embedding
 config.model_args.use_prosody_encoder = True
 config.model_args.prosody_embedding_dim = 64
+
+config.model_args.use_encoder_conditional_module = True
+
 # active classifier
 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json"
 config.model_args.use_prosody_enc_emo_classifier = False
 config.model_args.use_text_enc_emo_classifier = False
-config.model_args.use_prosody_encoder_z_p_input = True
+config.model_args.use_prosody_encoder_z_p_input = False

 config.model_args.prosody_encoder_type = "gst"
 config.model_args.detach_prosody_enc_input = True

@@ -64,7 +67,7 @@ config.model_args.use_prosody_embedding_squeezer = False
 config.model_args.prosody_embedding_squeezer_input_dim = 0

 # pitch predictor
-config.model_args.use_pitch = True
+config.model_args.use_pitch = False
 config.model_args.use_pitch_on_enc_input = False
 config.model_args.condition_dp_on_speaker = False