From 4b59f079462398a40650236da2ab986ab187714f Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Tue, 7 Jun 2022 13:29:22 -0300 Subject: [PATCH] Support the use of speaker embedding as emotion embedding --- TTS/bin/compute_embeddings.py | 3 +- TTS/bin/synthesize.py | 7 +- TTS/tts/layers/generic/classifier.py | 4 +- TTS/tts/layers/losses.py | 11 +- TTS/tts/layers/vits/discriminator.py | 4 +- TTS/tts/layers/vits/prosody_encoder.py | 4 +- TTS/tts/models/vits.py | 155 ++++++++++++------ TTS/tts/utils/emotions.py | 12 +- TTS/tts/utils/synthesis.py | 1 - TTS/utils/synthesizer.py | 16 +- .../test_glow_tts_speaker_emb_train.py | 3 + ...est_vits_speaker_emb_with_emotion_train.py | 13 +- ..._using_speaker_embedding_as_emotion_emb.py | 90 ++++++++++ 13 files changed, 244 insertions(+), 79 deletions(-) create mode 100644 tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 7364deee..a9b0ab2d 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -7,8 +7,7 @@ from tqdm import tqdm from TTS.config import load_config from TTS.tts.datasets import load_tts_samples -from TTS.tts.utils.managers import save_file -from TTS.tts.utils.managers import EmbeddingManager +from TTS.tts.utils.managers import EmbeddingManager, save_file parser = argparse.ArgumentParser( description="""Compute embedding vectors for each wav file in a dataset.\n\n""" diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index ee44731a..570e49ea 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -179,7 +179,12 @@ If you don't specify any models, then it uses LJSpeech based English model. default=None, ) parser.add_argument("--style_wav", type=str, help="Wav path file for prosody reference.", default=None) - parser.add_argument("--style_speaker_name", type=str, help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", default=None) + parser.add_argument( + "--style_speaker_name", + type=str, + help="The speaker name from the style_wav. If not provide the speaker embedding will be computed using the speaker encoder.", + default=None, + ) parser.add_argument( "--capacitron_style_text", type=str, help="Transcription of the style_wav reference.", default=None ) diff --git a/TTS/tts/layers/generic/classifier.py b/TTS/tts/layers/generic/classifier.py index 1cd60006..33283721 100644 --- a/TTS/tts/layers/generic/classifier.py +++ b/TTS/tts/layers/generic/classifier.py @@ -32,7 +32,9 @@ class ReversalClassifier(nn.Module): reversal (bool): If True reversal the gradients. 
Default: True """ - def __init__(self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True): + def __init__( + self, in_channels, out_channels, hidden_channels, gradient_clipping_bounds=0.25, scale_factor=1.0, reversal=True + ): super().__init__() self.reversal = reversal self._lambda = scale_factor diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 0d8fbe0b..0d171c1c 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -9,13 +9,8 @@ from torch.nn import functional from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.ssim import ssim from TTS.utils.audio import TorchSTFT -from TTS.vocoder.layers.losses import ( - MelganFeatureLoss, - MSEDLoss, - MSEGLoss, - _apply_D_loss, - _apply_G_adv_loss, -) +from TTS.vocoder.layers.losses import MelganFeatureLoss, MSEDLoss, MSEGLoss, _apply_D_loss, _apply_G_adv_loss + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 @@ -730,7 +725,6 @@ class VitsGeneratorLoss(nn.Module): loss += loss_prosody_enc_emo_classifier return_dict["loss_prosody_enc_emo_classifier"] = loss_prosody_enc_emo_classifier - if loss_text_enc_spk_rev_classifier is not None: loss_text_enc_spk_rev_classifier = loss_text_enc_spk_rev_classifier * self.speaker_classifier_alpha loss += loss_text_enc_spk_rev_classifier @@ -779,7 +773,6 @@ class VitsDiscriminatorLoss(nn.Module): self.disc_latent_loss_alpha = c.disc_latent_loss_alpha self.disc_latent_gan_loss = MSEDLoss() - @staticmethod def discriminator_loss(scores_real, scores_fake): loss = 0 diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 8ec67d1e..a760d54e 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -109,7 +109,9 @@ class LatentDiscriminator(nn.Module): norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm self.discriminators = nn.ModuleList( [ - norm_f(nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d(1 if hidden_channels is None else hidden_channels, 32, kernel_size=(3, 9), padding=(1, 4)) + ), norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), diff --git a/TTS/tts/layers/vits/prosody_encoder.py b/TTS/tts/layers/vits/prosody_encoder.py index ea8d11f6..27571da0 100644 --- a/TTS/tts/layers/vits/prosody_encoder.py +++ b/TTS/tts/layers/vits/prosody_encoder.py @@ -1,5 +1,6 @@ -from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST + class VitsGST(GST): def __init__(self, *args, **kwargs): @@ -9,6 +10,7 @@ class VitsGST(GST): style_embed = super().forward(inputs, speaker_embedding=speaker_embedding) return style_embed, None + class VitsVAE(CapacitronVAE): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 4ccd4694..5be47885 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -19,12 +19,11 @@ from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset, _parse_sample from TTS.tts.layers.generic.classifier import ReversalClassifier from 
TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor -from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder -from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.layers.vits.prosody_encoder import VitsGST, VitsVAE from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor - from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.emotions import EmotionManager from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask @@ -37,7 +36,6 @@ from TTS.tts.utils.visual import plot_alignment from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results - ############################## # IO / Feature extraction ############################## @@ -684,7 +682,11 @@ class Vits(BaseTTS): dp_cond_embedding_dim += self.args.prosody_embedding_dim dp_extra_inp_dim = 0 - if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings or self.args.use_speaker_embedding_as_emotion) and not self.args.use_noise_scale_predictor: + if ( + self.args.use_emotion_embedding + or self.args.use_external_emotions_embeddings + or self.args.use_speaker_embedding_as_emotion + ) and not self.args.use_noise_scale_predictor: dp_extra_inp_dim += self.args.emotion_embedding_dim if self.args.use_prosody_encoder and not self.args.use_noise_scale_predictor: @@ -711,22 +713,22 @@ class Vits(BaseTTS): ) if self.args.use_prosody_encoder: - if self.args.prosody_encoder_type == 'gst': + if self.args.prosody_encoder_type == "gst": self.prosody_encoder = VitsGST( num_mel=self.args.hidden_channels, num_heads=self.args.prosody_encoder_num_heads, num_style_tokens=self.args.prosody_encoder_num_tokens, gst_embedding_dim=self.args.prosody_embedding_dim, ) - elif self.args.prosody_encoder_type == 'vae': + elif self.args.prosody_encoder_type == "vae": self.prosody_encoder = VitsVAE( num_mel=self.args.hidden_channels, capacitron_VAE_embedding_dim=self.args.prosody_embedding_dim, ) else: raise RuntimeError( - f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!" - ) + f" [!] The Prosody encoder type {self.args.prosody_encoder_type} is not supported !!" 
+ ) if self.args.use_prosody_enc_spk_reversal_classifier: self.speaker_reversal_classifier = ReversalClassifier( in_channels=self.args.prosody_embedding_dim, @@ -738,12 +740,16 @@ class Vits(BaseTTS): in_channels=self.args.prosody_embedding_dim, out_channels=self.num_emotions, hidden_channels=256, - reversal=False + reversal=False, ) if self.args.use_noise_scale_predictor: noise_scale_predictor_input_dim = self.args.hidden_channels - if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings: + if ( + self.args.use_emotion_embedding + or self.args.use_external_emotions_embeddings + or self.args.use_speaker_embedding_as_emotion + ): noise_scale_predictor_input_dim += self.args.emotion_embedding_dim if self.args.use_prosody_encoder: @@ -763,15 +769,18 @@ class Vits(BaseTTS): ) if self.args.use_emotion_embedding_squeezer: - self.emotion_embedding_squeezer = nn.Linear(in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim) + self.emotion_embedding_squeezer = nn.Linear( + in_features=self.args.emotion_embedding_squeezer_input_dim, out_features=self.args.emotion_embedding_dim + ) if self.args.use_speaker_embedding_squeezer: - self.speaker_embedding_squeezer = nn.Linear(in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim) + self.speaker_embedding_squeezer = nn.Linear( + in_features=self.args.speaker_embedding_squeezer_input_dim, out_features=self.cond_embedding_dim + ) if self.args.use_text_enc_spk_reversal_classifier: self.speaker_text_enc_reversal_classifier = ReversalClassifier( - in_channels=self.args.hidden_channels - + dp_extra_inp_dim, + in_channels=self.args.hidden_channels + dp_extra_inp_dim, out_channels=self.num_speakers, hidden_channels=256, ) @@ -781,7 +790,7 @@ class Vits(BaseTTS): in_channels=self.args.hidden_channels, out_channels=self.num_emotions, hidden_channels=256, - reversal=False + reversal=False, ) self.waveform_decoder = HifiganGenerator( @@ -1176,9 +1185,16 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) + if self.args.use_speaker_embedding_as_emotion: + eg = g + # squeezers if self.args.use_emotion_embedding_squeezer: - if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings: + if ( + self.args.use_emotion_embedding + or self.args.use_external_emotions_embeddings + or self.args.use_speaker_embedding_as_emotion + ): eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1) if self.args.use_speaker_embedding_squeezer: @@ -1200,7 +1216,7 @@ class Vits(BaseTTS): prosody_encoder_input = z_p if self.args.use_prosody_encoder_z_p_input else z pros_emb, vae_outputs = self.prosody_encoder( prosody_encoder_input.detach() if self.args.detach_prosody_enc_input else prosody_encoder_input, - y_lengths + y_lengths, ) pros_emb = pros_emb.transpose(1, 2) @@ -1215,7 +1231,7 @@ class Vits(BaseTTS): x_lengths, lang_emb=lang_emb, emo_emb=eg if not self.args.use_noise_scale_predictor else None, - pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None + pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None, ) # reversal speaker loss to force the encoder to be speaker identity free @@ -1250,10 +1266,14 @@ class Vits(BaseTTS): if self.args.use_noise_scale_predictor: nsp_input = torch.transpose(m_p_expanded, 1, -1) if self.args.use_prosody_encoder and pros_emb is not None: - nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 
1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1) + nsp_input = torch.cat( + (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1 + ) if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None: - nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1) + nsp_input = torch.cat( + (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1 + ) nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask) @@ -1314,7 +1334,7 @@ class Vits(BaseTTS): "loss_prosody_enc_spk_rev_classifier": l_pros_speaker, "loss_prosody_enc_emo_classifier": l_pros_emotion, "loss_text_enc_spk_rev_classifier": l_text_speaker, - "loss_text_enc_emo_classifier": l_text_emotion + "loss_text_enc_emo_classifier": l_text_emotion, } ) return outputs @@ -1373,9 +1393,16 @@ class Vits(BaseTTS): if self.args.use_language_embedding and lid is not None: lang_emb = self.emb_l(lid).unsqueeze(-1) + if self.args.use_speaker_embedding_as_emotion: + eg = g + # squeezers if self.args.use_emotion_embedding_squeezer: - if self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings: + if ( + self.args.use_emotion_embedding + or self.args.use_external_emotions_embeddings + or self.args.use_speaker_embedding_as_emotion + ): eg = F.normalize(self.emotion_embedding_squeezer(eg.squeeze(-1))).unsqueeze(-1) if self.args.use_speaker_embedding_squeezer: @@ -1399,13 +1426,12 @@ class Vits(BaseTTS): pros_emb, _ = self.prosody_encoder(z_p_inf, pf_lengths) pros_emb = pros_emb.transpose(1, 2) - x, m_p, logs_p, x_mask = self.text_encoder( x, x_lengths, lang_emb=lang_emb, emo_emb=eg if not self.args.use_noise_scale_predictor else None, - pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None + pros_emb=pros_emb if not self.args.use_noise_scale_predictor else None, ) # duration predictor @@ -1448,10 +1474,14 @@ class Vits(BaseTTS): if self.args.use_noise_scale_predictor: nsp_input = torch.transpose(m_p, 1, -1) if self.args.use_prosody_encoder and pros_emb is not None: - nsp_input = torch.cat((nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1) + nsp_input = torch.cat( + (nsp_input, pros_emb.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1 + ) if (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings) and eg is not None: - nsp_input = torch.cat((nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1) + nsp_input = torch.cat( + (nsp_input, eg.transpose(2, 1).expand(nsp_input.size(0), nsp_input.size(1), -1)), dim=-1 + ) nsp_input = torch.transpose(nsp_input, 1, -1) * y_mask m_p_noise_scale = self.noise_scale_predictor(nsp_input, y_mask) @@ -1521,7 +1551,6 @@ class Vits(BaseTTS): wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) return wav - def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): """Forward pass for voice conversion @@ -1550,7 +1579,6 @@ class Vits(BaseTTS): o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) return o_hat, y_mask, (z, z_p, z_hat) - def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: """Perform a single training step. Run the model forward pass and compute losses. 
@@ -1599,18 +1627,16 @@ class Vits(BaseTTS): self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init # compute scores and features - scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _= self.disc( - outputs["model_outputs"].detach(), outputs["waveform_seg"], outputs["m_p"].detach(), outputs["z_p"].detach() + scores_disc_fake, _, scores_disc_real, _, scores_disc_mp, _, scores_disc_zp, _ = self.disc( + outputs["model_outputs"].detach(), + outputs["waveform_seg"], + outputs["m_p"].detach(), + outputs["z_p"].detach(), ) # compute loss with autocast(enabled=False): # use float32 for the criterion - loss_dict = criterion[optimizer_idx]( - scores_disc_real, - scores_disc_fake, - scores_disc_zp, - scores_disc_mp - ) + loss_dict = criterion[optimizer_idx](scores_disc_real, scores_disc_fake, scores_disc_zp, scores_disc_mp) return outputs, loss_dict if optimizer_idx == 1: @@ -1640,8 +1666,20 @@ class Vits(BaseTTS): ) # compute discriminator scores and features - scores_disc_fake, feats_disc_fake, _, feats_disc_real, scores_disc_mp, feats_disc_mp, _, feats_disc_zp = self.disc( - self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"], self.model_outputs_cache["m_p"], self.model_outputs_cache["z_p"].detach() + ( + scores_disc_fake, + feats_disc_fake, + _, + feats_disc_real, + scores_disc_mp, + feats_disc_mp, + _, + feats_disc_zp, + ) = self.disc( + self.model_outputs_cache["model_outputs"], + self.model_outputs_cache["waveform_seg"], + self.model_outputs_cache["m_p"], + self.model_outputs_cache["z_p"].detach(), ) # compute losses @@ -1669,7 +1707,7 @@ class Vits(BaseTTS): loss_text_enc_emo_classifier=self.model_outputs_cache["loss_text_enc_emo_classifier"], scores_disc_mp=scores_disc_mp, feats_disc_mp=feats_disc_mp, - feats_disc_zp=feats_disc_zp + feats_disc_zp=feats_disc_zp, ) return self.model_outputs_cache, loss_dict @@ -1729,7 +1767,14 @@ class Vits(BaseTTS): config = self.config # extract speaker and language info - text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = None, None, None, None, None, None + text, speaker_name, style_wav, language_name, emotion_name, style_speaker_name = ( + None, + None, + None, + None, + None, + None, + ) if isinstance(sentence_info, list): if len(sentence_info) == 1: @@ -1748,12 +1793,18 @@ class Vits(BaseTTS): text = sentence_info if style_wav and style_speaker_name is None: - raise RuntimeError( - " [!] You must to provide the style_speaker_name for the style_wav !!" - ) + raise RuntimeError(" [!] 
You must to provide the style_speaker_name for the style_wav !!") # get speaker id/d_vector - speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = None, None, None, None, None, None, None + speaker_id, d_vector, language_id, emotion_id, emotion_embedding, style_speaker_id, style_speaker_d_vector = ( + None, + None, + None, + None, + None, + None, + None, + ) if hasattr(self, "speaker_manager"): if config.use_d_vector_file: if speaker_name is None: @@ -1762,7 +1813,9 @@ class Vits(BaseTTS): d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) if style_wav is not None: - style_speaker_d_vector = self.speaker_manager.get_mean_embedding(style_speaker_name, num_samples=None, randomize=False) + style_speaker_d_vector = self.speaker_manager.get_mean_embedding( + style_speaker_name, num_samples=None, randomize=False + ) elif config.use_speaker_embedding: if speaker_name is None: @@ -1893,7 +1946,15 @@ class Vits(BaseTTS): emotion_embeddings = [emotion_mapping[w]["embedding"] for w in batch["audio_files"]] emotion_embeddings = torch.FloatTensor(emotion_embeddings) - if self.emotion_manager is not None and self.emotion_manager.embeddings and (self.args.use_emotion_embedding or self.args.use_prosody_enc_emo_classifier or self.args.use_text_enc_emo_classifier): + if ( + self.emotion_manager is not None + and self.emotion_manager.embeddings + and ( + self.args.use_emotion_embedding + or self.args.use_prosody_enc_emo_classifier + or self.args.use_text_enc_emo_classifier + ) + ): emotion_mapping = self.emotion_manager.embeddings emotion_names = [emotion_mapping[w]["name"] for w in batch["audio_files"]] emotion_ids = [self.emotion_manager.ids[en] for en in emotion_names] diff --git a/TTS/tts/utils/emotions.py b/TTS/tts/utils/emotions.py index 57cd8060..80c01e12 100644 --- a/TTS/tts/utils/emotions.py +++ b/TTS/tts/utils/emotions.py @@ -94,7 +94,11 @@ class EmotionManager(EmbeddingManager): EmotionEncoder: Emotion encoder object. 
""" emotion_manager = None - if get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False): + if ( + get_from_config_or_model_args_with_default(config, "use_emotion_embedding", False) + or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) + or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False) + ): if get_from_config_or_model_args_with_default(config, "emotions_ids_file", None): emotion_manager = EmotionManager( emotion_id_file_path=get_from_config_or_model_args_with_default(config, "emotions_ids_file", None) @@ -106,7 +110,11 @@ class EmotionManager(EmbeddingManager): ) ) - if get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False): + if ( + get_from_config_or_model_args_with_default(config, "use_external_emotions_embeddings", False) + or get_from_config_or_model_args_with_default(config, "use_prosody_enc_emo_classifier", False) + or get_from_config_or_model_args_with_default(config, "use_text_enc_emo_classifier", False) + ): if get_from_config_or_model_args_with_default(config, "external_emotions_embs_file", None): emotion_manager = EmotionManager( embeddings_file_path=get_from_config_or_model_args_with_default( diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index b625b39e..4a9141d0 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -217,7 +217,6 @@ def synthesis( if style_speaker_d_vector is not None: style_speaker_d_vector = embedding_to_torch(style_speaker_d_vector, cuda=use_cuda) - if not isinstance(style_feature, dict): # GST or Capacitron style mel style_feature = numpy_to_torch(style_feature, torch.float, cuda=use_cuda) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 5ff7e41f..e9a79220 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -306,9 +306,17 @@ class Synthesizer(object): # handle emotion emotion_embedding, emotion_id = None, None - if not reference_wav and not getattr(self.tts_model, "prosody_encoder", False) and (self.tts_emotions_file or ( - getattr(self.tts_model, "emotion_manager", None) and getattr(self.tts_model.emotion_manager, "ids", None) - )): + if ( + not reference_wav + and not getattr(self.tts_model, "prosody_encoder", False) + and ( + self.tts_emotions_file + or ( + getattr(self.tts_model, "emotion_manager", None) + and getattr(self.tts_model.emotion_manager, "ids", None) + ) + ) + ): if emotion_name and isinstance(emotion_name, str): if getattr(self.tts_config, "use_external_emotions_embeddings", False) or ( getattr(self.tts_config, "model_args", None) @@ -426,7 +434,7 @@ class Synthesizer(object): d_vector=speaker_embedding, use_griffin_lim=use_gl, reference_speaker_id=reference_speaker_id, - reference_d_vector=reference_speaker_embedding + reference_d_vector=reference_speaker_embedding, ) waveform = outputs if not use_gl: diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py index 322b506e..81a17e7a 100644 --- a/tests/tts_tests/test_glow_tts_speaker_emb_train.py +++ 
b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -58,6 +58,9 @@ continue_restore_path, _ = get_last_checkpoint(continue_path) out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") +if not os.path.isfile(continue_speakers_path): + continue_speakers_path = continue_speakers_path.replace(".json", ".pth") + # Check integrity of the config with open(continue_config_path, "r", encoding="utf-8") as f: diff --git a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py index 28b8f203..bef67ee5 100644 --- a/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_with_emotion_train.py @@ -37,23 +37,16 @@ config.audio.trim_db = 60 config.model_args.use_speaker_embedding = False config.model_args.use_d_vector_file = True config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" -config.model_args.speaker_embedding_channels = 128 -config.model_args.d_vector_dim = 100 +config.model_args.speaker_embedding_channels = 256 +config.model_args.d_vector_dim = 256 # emotion config.model_args.use_external_emotions_embeddings = True config.model_args.use_emotion_embedding = False -config.model_args.emotion_embedding_dim = 64 +config.model_args.emotion_embedding_dim = 256 config.model_args.external_emotions_embs_file = "tests/data/ljspeech/speakers.json" config.model_args.use_text_enc_spk_reversal_classifier = False - -config.model_args.use_emotion_embedding_squeezer = True -config.model_args.emotion_embedding_squeezer_input_dim = 256 - -config.model_args.use_speaker_embedding_squeezer = True -config.model_args.speaker_embedding_squeezer_input_dim = 256 - # consistency loss # config.model_args.use_emotion_encoder_as_loss = True # config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar" diff --git a/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py b/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py new file mode 100644 index 00000000..ac157c58 --- /dev/null +++ b/tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py @@ -0,0 +1,90 @@ +import glob +import os +import shutil + +from trainer import get_last_checkpoint + +from tests import get_device_id, get_tests_output_path, run_cli +from TTS.tts.configs.vits_config import VitsConfig + +config_path = os.path.join(get_tests_output_path(), "test_model_config.json") +output_path = os.path.join(get_tests_output_path(), "train_outputs") + + +config = VitsConfig( + batch_size=2, + eval_batch_size=2, + num_loader_workers=0, + num_eval_loader_workers=0, + text_cleaner="english_cleaners", + use_phonemes=True, + phoneme_language="en-us", + phoneme_cache_path="tests/data/ljspeech/phoneme_cache/", + run_eval=True, + test_delay_epochs=-1, + epochs=1, + print_step=1, + print_eval=True, + test_sentences=[ + ["Be a voice, not an echo.", "ljspeech-1", None, None, "ljspeech-1"], + ], +) +# set audio config +config.audio.do_trim_silence = True +config.audio.trim_db = 60 + +# active multispeaker d-vec mode +config.model_args.use_speaker_embedding = False +config.model_args.use_d_vector_file = True +config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json" +config.model_args.speaker_embedding_channels = 128 +config.model_args.d_vector_dim = 256 + +# emotion 
+config.model_args.emotion_embedding_dim = 256 + +config.model_args.use_emotion_embedding_squeezer = False +config.model_args.emotion_embedding_squeezer_input_dim = 256 +config.model_args.use_speaker_embedding_as_emotion = True + +config.model_args.use_speaker_embedding_squeezer = False +config.model_args.speaker_embedding_squeezer_input_dim = 256 + +# consistency loss +# config.model_args.use_emotion_encoder_as_loss = True +# config.model_args.encoder_model_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/model_se.pth.tar" +# config.model_args.encoder_config_path = "/raid/edresson/dev/Checkpoints/Coqui-Realesead/tts_models--multilingual--multi-dataset--your_tts/config_se.json" + +config.save_json(config_path) + +# train the model for one epoch +command_train = ( + f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} " + f"--coqpit.output_path {output_path} " + "--coqpit.datasets.0.name ljspeech_test " + "--coqpit.datasets.0.meta_file_train metadata.csv " + "--coqpit.datasets.0.meta_file_val metadata.csv " + "--coqpit.datasets.0.path tests/data/ljspeech " + "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt " + "--coqpit.test_delay_epochs 0" +) +run_cli(command_train) + +# Find latest folder +continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime) + +# Inference using TTS API +continue_config_path = os.path.join(continue_path, "config.json") +continue_restore_path, _ = get_last_checkpoint(continue_path) +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") +speaker_id = "ljspeech-1" +continue_speakers_path = config.model_args.d_vector_file + + +inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" +run_cli(inference_command) + +# restore the model and continue training for one more epoch +command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} " +run_cli(command_train) +shutil.rmtree(continue_path)
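
For reference, a minimal configuration sketch of the new use_speaker_embedding_as_emotion option, mirroring the added test tests/tts_tests/test_vits_using_speaker_embedding_as_emotion_emb.py; the dataset path and the 256-dimensional d-vectors are assumptions carried over from that test, not requirements of the feature.

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig()

# Speaker conditioning through externally computed d-vectors (no learned speaker table).
config.model_args.use_speaker_embedding = False
config.model_args.use_d_vector_file = True
config.model_args.d_vector_file = "tests/data/ljspeech/speakers.json"  # path assumed from the test data
config.model_args.d_vector_dim = 256

# Reuse the speaker d-vector as the emotion embedding. With the squeezers left
# disabled, emotion_embedding_dim has to match d_vector_dim (both 256 in the test).
config.model_args.use_speaker_embedding_as_emotion = True
config.model_args.emotion_embedding_dim = 256
config.model_args.use_emotion_embedding_squeezer = False
config.model_args.use_speaker_embedding_squeezer = False

With this setup no emotions_ids_file or external_emotions_embs_file is needed: in forward/inference the speaker embedding is assigned to the emotion conditioning signal (eg = g) and fed to the text encoder and duration predictor exactly as an emotion embedding would be.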