get_aux_input

This commit is contained in:
Julian WEBER 2021-10-27 12:02:02 +02:00 committed by Eren Gölge
parent 5c89803968
commit 3440c54bbe
1 changed files with 62 additions and 2 deletions

View File

@ -387,6 +387,25 @@ class Vits(BaseTTS):
if config.use_d_vector_file:
self._init_d_vector(config)
# TODO: make this a function
if config.use_speaker_encoder_as_loss:
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
for param in self.speaker_encoder.parameters():
param.requires_grad = False
print(" > External Speaker Encoder Loaded !!")
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
else:
self.audio_transform = None
else:
self.audio_transform = None
self.speaker_encoder = None
def _init_speaker_embedding(self, config):
# pylint: disable=attribute-defined-outside-init
if config.speakers_file is not None:
@ -469,8 +488,49 @@ class Vits(BaseTTS):
return sid, g, lid
def get_aux_input(self, aux_input: Dict):
sid, g = self._set_cond_input(aux_input)
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
sid, g, lid = self._set_cond_input(aux_input)
return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}
def get_aux_input_from_test_setences(self, sentence_info):
if hasattr(self.config, "model_args"):
config = self.config.model_args
else:
config = self.config
# extract speaker and language info
text, speaker_name, style_wav, language_name = None, None, None, None
if isinstance(sentence_info, list):
if len(sentence_info) == 1:
text = sentence_info[0]
elif len(sentence_info) == 2:
text, speaker_name = sentence_info
elif len(sentence_info) == 3:
text, speaker_name, style_wav = sentence_info
elif len(sentence_info) == 4:
text, speaker_name, style_wav, language_name = sentence_info
else:
text = sentence_info
# get speaker id/d_vector
speaker_id, d_vector, language_id = None, None, None
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
def forward(
self,