From 39575f293711648cec9276bad6a2ef6ce3e74702 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Wed, 16 Mar 2022 20:57:14 +0000
Subject: [PATCH] Bug fix in single speaker emotion embedding training

---
 TTS/bin/compute_embeddings.py |  4 ++--
 TTS/tts/models/vits.py        | 19 +++++++++++++++----
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py
index 0a4c6e29..bef7b384 100644
--- a/TTS/bin/compute_embeddings.py
+++ b/TTS/bin/compute_embeddings.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from TTS.config import load_config
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.utils.speakers import SpeakerManager
+from TTS.tts.utils.managers import EmbeddingManager
 
 parser = argparse.ArgumentParser(
     description="""Compute embedding vectors for each wav file in a dataset.\n\n"""
@@ -48,7 +48,7 @@ if meta_data_eval is None:
 else:
     wav_files = meta_data_train + meta_data_eval
 
-encoder_manager = SpeakerManager(
+encoder_manager = EmbeddingManager(
     encoder_model_path=args.model_path,
     encoder_config_path=args.config_path,
     d_vectors_file_path=args.old_file,
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index df010de6..668400fd 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1030,7 +1030,10 @@ class Vits(BaseTTS):
 
         # concat the emotion embedding and speaker embedding
         if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-            g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
+            if g is None:
+                g = eg
+            else:
+                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
 
         # language embedding
         lang_emb = None
@@ -1145,8 +1148,11 @@ class Vits(BaseTTS):
             eg = self.emb_emotion(eid).unsqueeze(-1)  # [b, h, 1]
 
         # concat the emotion embedding and speaker embedding
-        if eg is not None and g is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
-            g = torch.cat([g, eg], dim=1)  # [b, h1+h1, 1]
+        if eg is not None and (self.args.use_emotion_embedding or self.args.use_external_emotions_embeddings):
+            if g is None:
+                g = eg
+            else:
+                g = torch.cat([g, eg], dim=1)  # [b, h1+h2, 1]
 
         # language embedding
         lang_emb = None
@@ -1780,10 +1786,15 @@ class Vits(BaseTTS):
 
         language_manager = LanguageManager.init_from_config(config)
         emotion_manager = EmotionManager.init_from_config(config)
 
-        if config.model_args.encoder_model_path:
+        if config.model_args.encoder_model_path and speaker_manager is not None:
             speaker_manager.init_encoder(
                 config.model_args.encoder_model_path, config.model_args.encoder_config_path
             )
+        elif config.model_args.encoder_model_path and emotion_manager is not None:
+            emotion_manager.init_encoder(
+                config.model_args.encoder_model_path, config.model_args.encoder_config_path
+            )
+
 
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager, emotion_manager=emotion_manager)