mirror of https://github.com/coqui-ai/TTS.git
Support prosody conditional model on decoder input
commit 512525cc39
parent 02194367d7
@@ -50,16 +50,14 @@ class ReversalClassifier(nn.Module):
         return x, loss
 
     @staticmethod
-    def loss(labels, predictions, x_mask):
-        ignore_index = -100
+    def loss(labels, predictions, x_mask=None, ignore_index=-100):
         if x_mask is None:
             x_mask = torch.Tensor([predictions.size(1)]).repeat(predictions.size(0)).int().to(predictions.device)
-
-        ml = torch.max(x_mask)
-        input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
-
-        target = labels.repeat(ml.int().item(), 1).transpose(0, 1)
+            ml = torch.max(x_mask)
+            input_mask = torch.arange(ml, device=predictions.device)[None, :] < x_mask[:, None]
+        else:
+            input_mask = x_mask.squeeze().bool()
+
+        target = labels.repeat(input_mask.size(-1), 1).transpose(0, 1).int().long()
         target[~input_mask] = ignore_index
 
         return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)
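The loss change above makes the time mask optional and surfaces ignore_index as a parameter. Below is a minimal, self-contained sketch of the same masked cross-entropy behaviour; it assumes x_mask arrives as a [B, 1, T] binary mask (as elsewhere in VITS) and is illustrative only, not code from the commit.

import torch
import torch.nn as nn

def reversal_loss(labels, predictions, x_mask=None, ignore_index=-100):
    # predictions: [B, T, n_classes] classifier outputs, labels: [B] class ids
    if x_mask is None:
        # no mask given: every frame of every utterance is valid
        lengths = torch.full((predictions.size(0),), predictions.size(1), device=predictions.device)
        input_mask = torch.arange(predictions.size(1), device=predictions.device)[None, :] < lengths[:, None]
    else:
        # assumed layout: x_mask is a [B, 1, T] binary mask
        input_mask = x_mask.squeeze(1).bool()
    # repeat the utterance-level label over time, then ignore padded frames
    target = labels[:, None].repeat(1, input_mask.size(-1)).long()
    target[~input_mask] = ignore_index
    return nn.functional.cross_entropy(predictions.transpose(1, 2), target, ignore_index=ignore_index)

preds = torch.randn(2, 5, 4)       # B=2, T=5, 4 classes
labels = torch.tensor([1, 3])
mask = torch.ones(2, 1, 5)
mask[1, 0, 3:] = 0                 # second utterance has 2 padded frames
print(reversal_loss(labels, preds), reversal_loss(labels, preds, x_mask=mask))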
@@ -552,6 +552,7 @@ class VitsArgs(Coqpit):
     use_prosody_enc_emo_classifier: bool = False
 
     use_prosody_conditional_flow_module: bool = False
+    prosody_conditional_flow_module_on_decoder: bool = False
 
     detach_dp_input: bool = True
     use_language_embedding: bool = False
@@ -1150,15 +1151,6 @@ class Vits(BaseTTS):
             pros_emb=pros_emb if not self.args.use_prosody_conditional_flow_module else None
         )
 
-        # conditional module
-        if self.args.use_prosody_conditional_flow_module:
-            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
-
-        # reversal speaker loss to force the encoder to be speaker identity free
-        l_text_emotion = None
-        if self.args.use_text_enc_emo_classifier:
-            _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=None)
-
         # reversal speaker loss to force the encoder to be speaker identity free
         l_text_speaker = None
         if self.args.use_text_enc_spk_reversal_classifier:
@@ -1167,6 +1159,24 @@ class Vits(BaseTTS):
         # flow layers
         z_p = self.flow(z, y_mask, g=g)
 
+        # reversal speaker loss to force the encoder to be speaker identity free
+        l_text_emotion = None
+        if self.args.use_text_enc_emo_classifier:
+            if self.args.prosody_conditional_flow_module_on_decoder:
+                _, l_text_emotion = self.emo_text_enc_classifier(z_p.transpose(1, 2), eid, x_mask=y_mask)
+
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+            else:
+                z_p = self.prosody_conditional_module(z_p, y_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+
+        # reversal speaker loss to force the encoder to be speaker identity free
+        if self.args.use_text_enc_emo_classifier:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                _, l_text_emotion = self.emo_text_enc_classifier(m_p.transpose(1, 2), eid, x_mask=x_mask)
+
         # duration predictor
         g_dp = g if self.args.condition_dp_on_speaker else None
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
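Taken together, the forward-pass changes above move the prosody conditioning and the emotion classifier to after the flow when the new flag is set. The toy sketch below (an assumption for illustration, not code from the commit) shows the shape of that branching: the conditional module transforms either the text-side prior mean m_p under the text mask or the decoder-side latent z_p under the spectrogram mask.

import torch

def apply_conditioning(on_decoder, module, m_p, x_mask, z_p, y_mask, cond):
    # mirrors the diff: one of the two tensors is transformed, the other passes through
    if not on_decoder:
        m_p = module(m_p, x_mask, g=cond)   # text-encoder side, before duration prediction
    else:
        z_p = module(z_p, y_mask, g=cond)   # decoder side, on the latent fed to the main flow
    return m_p, z_p

dummy_module = lambda x, mask, g=None: x * mask              # stand-in for the conditional flow
m_p, x_mask = torch.randn(1, 192, 12), torch.ones(1, 1, 12)  # text-length tensors
z_p, y_mask = torch.randn(1, 192, 80), torch.ones(1, 1, 80)  # spectrogram-length tensors
cond = torch.randn(1, 256, 1)                                # emotion or prosody embedding
print([t.shape for t in apply_conditioning(True, dummy_module, m_p, x_mask, z_p, y_mask, cond)])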
@@ -1318,6 +1328,7 @@ class Vits(BaseTTS):
 
         # conditional module
         if self.args.use_prosody_conditional_flow_module:
-            m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                m_p = self.prosody_conditional_module(m_p, x_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb)
 
         # duration predictor
@@ -1358,6 +1369,12 @@ class Vits(BaseTTS):
         logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2)
 
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale
 
+        # conditional module
+        if self.args.use_prosody_conditional_flow_module:
+            if self.args.prosody_conditional_flow_module_on_decoder:
+                z_p = self.prosody_conditional_module(z_p, y_mask, g=eg if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) else pros_emb, reverse=True)
+
         z = self.flow(z_p, y_mask, g=g, reverse=True)
 
         # upsampling if needed
@@ -1394,7 +1411,7 @@ class Vits(BaseTTS):
 
     @torch.no_grad()
     def inference_voice_conversion(
-        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None
+        self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None, ref_emotion=None, target_emotion=None
     ):
         """Inference for voice conversion
 
@@ -1417,10 +1434,11 @@ class Vits(BaseTTS):
         speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector
         speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector
         # print(y.shape, y_lengths.shape)
-        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt)
+        wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt, ref_emotion, target_emotion)
         return wav
 
-    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt):
+    def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt, ref_emotion=None, target_emotion=None):
         """Forward pass for voice conversion
 
         TODO: create an end-point for voice conversion
@@ -1441,13 +1459,31 @@ class Vits(BaseTTS):
             g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1)
         else:
             raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.")
+        # emotion embedding
+        if self.args.use_emotion_embedding and ref_emotion is not None and target_emotion is not None:
+            ge_src = self.emb_g(ref_emotion).unsqueeze(-1)
+            ge_tgt = self.emb_g(target_emotion).unsqueeze(-1)
+        elif self.args.use_external_emotions_embeddings and ref_emotion is not None and target_emotion is not None:
+            ge_src = F.normalize(ref_emotion).unsqueeze(-1)
+            ge_tgt = F.normalize(target_emotion).unsqueeze(-1)
+
         z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src)
         z_p = self.flow(z, y_mask, g=g_src)
 
+        # change the emotion
+        if ge_tgt is not None and ge_tgt is not None and self.args.use_prosody_conditional_flow_module:
+            if not self.args.prosody_conditional_flow_module_on_decoder:
+                ze = self.prosody_conditional_module(z_p, y_mask, g=ge_src, reverse=True)
+                z_p = self.prosody_conditional_module(ze, y_mask, g=ge_tgt)
+            else:
+                ze = self.prosody_conditional_module(z_p, y_mask, g=ge_src)
+                z_p = self.prosody_conditional_module(ze, y_mask, g=ge_tgt, reverse=True)
+
         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
         o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt)
         return o_hat, y_mask, (z, z_p, z_hat)
 
     def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
         """Perform a single training step. Run the model forward pass and compute losses.
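The emotion swap added to voice_conversion runs the conditional module twice: once conditioned on the source emotion and once on the target, with the inverse/forward order flipping when the module sits on the decoder side. Below is a toy, runnable sketch of that round trip using an invertible affine stand-in; it is an assumption for illustration, not the repository's module.

import torch

class ToyConditionalFlow(torch.nn.Module):
    # stand-in for prosody_conditional_module: an invertible, embedding-conditioned affine
    def __init__(self, channels, cond_dim):
        super().__init__()
        self.scale = torch.nn.Linear(cond_dim, channels)
        self.shift = torch.nn.Linear(cond_dim, channels)

    def forward(self, x, mask, g=None, reverse=False):
        s = torch.tanh(self.scale(g.squeeze(-1)))[:, :, None]  # [B, C, 1]
        b = self.shift(g.squeeze(-1))[:, :, None]
        if not reverse:
            return (x * torch.exp(s) + b) * mask
        return ((x - b) * torch.exp(-s)) * mask                # exact inverse of the forward pass

flow = ToyConditionalFlow(channels=8, cond_dim=4)
z_p = torch.randn(1, 8, 20)                                    # decoder-side latent [B, C, T]
y_mask = torch.ones(1, 1, 20)
ge_src, ge_tgt = torch.randn(1, 4, 1), torch.randn(1, 4, 1)    # source / target emotion embeddings

# same order as the diff's non-decoder branch: invert with the source, re-apply with the target
ze = flow(z_p, y_mask, g=ge_src, reverse=True)
z_p_swapped = flow(ze, y_mask, g=ge_tgt)
print(z_p_swapped.shape)  # torch.Size([1, 8, 20])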
@@ -266,6 +266,8 @@ def transfer_voice(
     reference_d_vector=None,
     do_trim_silence=False,
     use_griffin_lim=False,
+    source_emotion_feature=None,
+    target_emotion_feature=None,
 ):
     """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to
     the vocoder model.
@@ -311,6 +313,12 @@ def transfer_voice(
     if reference_d_vector is not None:
         reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
 
+    if source_emotion_feature is not None:
+        source_emotion_feature = embedding_to_torch(source_emotion_feature, cuda=use_cuda)
+
+    if target_emotion_feature is not None:
+        target_emotion_feature = embedding_to_torch(target_emotion_feature, cuda=use_cuda)
+
     # load reference_wav audio
     reference_wav = embedding_to_torch(model.ap.load_wav(reference_wav, sr=model.ap.sample_rate), cuda=use_cuda)
 
@@ -318,7 +326,7 @@ def transfer_voice(
         _func = model.module.inference_voice_conversion
     else:
         _func = model.inference_voice_conversion
-    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector)
+    model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector, ref_emotion=source_emotion_feature, target_emotion=target_emotion_feature)
 
     # convert outputs to numpy
     # plot results
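End to end, the new keyword arguments flow from Synthesizer.tts through transfer_voice into Vits.inference_voice_conversion. A hypothetical usage sketch follows; the checkpoint paths, speaker name, and emotion names are placeholders, and the surrounding Synthesizer arguments follow the upstream API rather than anything introduced by this commit.

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="best_model.pth",        # placeholder checkpoint path
    tts_config_path="config.json",          # placeholder config path
    use_cuda=False,
)
wav = synthesizer.tts(
    reference_wav="source_utterance.wav",   # speech to convert
    speaker_name="target_speaker",          # target speaker identity
    source_emotion="Neutral",               # emotion of the reference utterance
    target_emotion="Happy",                 # emotion to impose on the converted speech
)
synthesizer.save_wav(wav, "converted.wav")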
@@ -214,6 +214,8 @@ class Synthesizer(object):
         reference_wav=None,
         reference_speaker_name=None,
         emotion_name=None,
+        source_emotion=None,
+        target_emotion=None,
     ) -> List[int]:
         """🐸 TTS magic. Run all the models and generate speech.
 
@@ -293,7 +295,7 @@ class Synthesizer(object):
 
         # handle emotion
         emotion_embedding, emotion_id = None, None
-        if not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
+        if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
             getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
         )):
             if emotion_name and isinstance(emotion_name, str):
@@ -398,6 +400,32 @@ class Synthesizer(object):
                 reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(
                     reference_wav
                 )
+            # get the emotions embeddings
+            # handle emotion
+            source_emotion_feature, target_emotion_feature = None, None
+            if source_emotion is not None and target_emotion is not None and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or (
+                getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None)
+            )):
+                if source_emotion and isinstance(source_emotion, str):
+                    if getattr(self.tts_config, "use_external_emotions_embeddings", False) or (
+                        getattr(self.tts_config, "model_args", None)
+                        and getattr(self.tts_config.model_args, "use_external_emotions_embeddings", False)
+                    ):
+                        # get the average emotion embedding from the saved embeddings.
+                        source_emotion_feature = self.tts_model.emotion_manager.get_mean_embedding(
+                            source_emotion, num_samples=None, randomize=False
+                        )
+                        source_emotion_feature = np.array(source_emotion_feature)[None, :]  # [1 x embedding_dim]
+                        # target
+                        target_emotion_feature = self.tts_model.emotion_manager.get_mean_embedding(
+                            target_emotion, num_samples=None, randomize=False
+                        )
+                        target_emotion_feature = np.array(target_emotion_feature)[None, :]  # [1 x embedding_dim]
+                    else:
+                        # get emotion idx
+                        source_emotion_feature = self.tts_model.emotion_manager.ids[source_emotion]
+                        target_emotion_feature = self.tts_model.emotion_manager.ids[target_emotion]
+
             outputs = transfer_voice(
                 model=self.tts_model,
@@ -409,6 +437,8 @@ class Synthesizer(object):
                 use_griffin_lim=use_gl,
                 reference_speaker_id=reference_speaker_id,
                 reference_d_vector=reference_speaker_embedding,
+                source_emotion_feature=source_emotion_feature,
+                target_emotion_feature=target_emotion_feature,
             )
             waveform = outputs
             if not use_gl:
@@ -49,6 +49,9 @@ config.model_args.use_text_enc_spk_reversal_classifier = False
 
 
 config.model_args.use_prosody_conditional_flow_module = True
+config.model_args.prosody_conditional_flow_module_on_decoder = True
+config.model_args.use_text_enc_emo_classifier = True
+
 # consistency loss
 # config.model_args.use_emotion_encoder_as_loss = True
 # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar"