def get_aux_input(self, aux_input: Dict):
    """Resolve the conditioning inputs used at inference time.

    Args:
        aux_input (Dict): Auxiliary inputs, forwarded unchanged to
            ``self._set_cond_input``.

    Returns:
        Dict: ``speaker_id``, ``d_vector`` and ``language_id`` exactly as
        returned by ``self._set_cond_input``; ``style_wav`` is always ``None``.
    """
    sid, g, lid = self._set_cond_input(aux_input)
    return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}


def get_aux_input_from_test_setences(self, sentence_info):
    """Build the auxiliary inputs for one test sentence.

    Args:
        sentence_info: Either a plain value used as the text, or a list/tuple
            laid out as ``[text, speaker_name, style_wav, language_name]``
            where trailing items may be omitted.

    Returns:
        Dict: A dict with the keys ``text``, ``speaker_id``, ``style_wav``,
        ``d_vector`` and ``language_id``. Entries that could not be resolved
        are ``None``.

    Raises:
        KeyError: If ``speaker_name`` or ``language_name`` is given but is not
            present in the corresponding manager mapping.
    """
    # The relevant flags are read from `model_args` when the config has one,
    # otherwise from the config object itself.
    config = getattr(self.config, "model_args", self.config)

    # extract speaker and language info
    text, speaker_name, style_wav, language_name = None, None, None, None
    if isinstance(sentence_info, (list, tuple)):
        # Pad short entries with None so a single unpack covers 1 to 4 items.
        # Empty or over-long sequences leave every field as None, which is
        # what the original length ladder did.
        if 1 <= len(sentence_info) <= 4:
            padding = [None] * (4 - len(sentence_info))
            text, speaker_name, style_wav, language_name = [*sentence_info, *padding]
    else:
        text = sentence_info

    # get speaker id/d_vector
    speaker_id, d_vector, language_id = None, None, None
    # `hasattr` is true for an attribute that is set to None, so test the
    # value itself before calling into the manager.
    speaker_manager = getattr(self, "speaker_manager", None)
    if speaker_manager is not None:
        if config.use_d_vector_file:
            if speaker_name is None:
                d_vector = speaker_manager.get_random_d_vector()
            else:
                d_vector = speaker_manager.get_d_vector_by_speaker(speaker_name)
        elif config.use_speaker_embedding:
            if speaker_name is None:
                speaker_id = speaker_manager.get_random_speaker_id()
            else:
                speaker_id = speaker_manager.speaker_ids[speaker_name]

    # get language id
    language_manager = getattr(self, "language_manager", None)
    if language_manager is not None and config.use_language_embedding and language_name is not None:
        language_id = language_manager.language_id_mapping[language_name]

    return {
        "text": text,
        "speaker_id": speaker_id,
        "style_wav": style_wav,
        "d_vector": d_vector,
        "language_id": language_id,
    }