mirror of https://github.com/coqui-ai/TTS.git
get_aux_input
This commit is contained in:
parent
b3abd01793
commit
9a2f91327c
|
@ -387,6 +387,25 @@ class Vits(BaseTTS):
|
|||
if config.use_d_vector_file:
|
||||
self._init_d_vector(config)
|
||||
|
||||
# TODO: make this a function
|
||||
if config.use_speaker_encoder_as_loss:
|
||||
if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
|
||||
raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
|
||||
self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
|
||||
self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
|
||||
for param in self.speaker_encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
print(" > External Speaker Encoder Loaded !!")
|
||||
|
||||
if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
|
||||
self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
|
||||
else:
|
||||
self.audio_transform = None
|
||||
else:
|
||||
self.audio_transform = None
|
||||
self.speaker_encoder = None
|
||||
|
||||
def _init_speaker_embedding(self, config):
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
if config.speakers_file is not None:
|
||||
|
@ -469,8 +488,49 @@ class Vits(BaseTTS):
|
|||
return sid, g, lid
|
||||
|
||||
def get_aux_input(self, aux_input: Dict):
|
||||
sid, g = self._set_cond_input(aux_input)
|
||||
return {"speaker_id": sid, "style_wav": None, "d_vector": g}
|
||||
sid, g, lid = self._set_cond_input(aux_input)
|
||||
return {"speaker_id": sid, "style_wav": None, "d_vector": g, "language_id": lid}
|
||||
|
||||
def get_aux_input_from_test_setences(self, sentence_info):
|
||||
if hasattr(self.config, "model_args"):
|
||||
config = self.config.model_args
|
||||
else:
|
||||
config = self.config
|
||||
|
||||
# extract speaker and language info
|
||||
text, speaker_name, style_wav, language_name = None, None, None, None
|
||||
|
||||
if isinstance(sentence_info, list):
|
||||
if len(sentence_info) == 1:
|
||||
text = sentence_info[0]
|
||||
elif len(sentence_info) == 2:
|
||||
text, speaker_name = sentence_info
|
||||
elif len(sentence_info) == 3:
|
||||
text, speaker_name, style_wav = sentence_info
|
||||
elif len(sentence_info) == 4:
|
||||
text, speaker_name, style_wav, language_name = sentence_info
|
||||
else:
|
||||
text = sentence_info
|
||||
|
||||
# get speaker id/d_vector
|
||||
speaker_id, d_vector, language_id = None, None, None
|
||||
if hasattr(self, "speaker_manager"):
|
||||
if config.use_d_vector_file:
|
||||
if speaker_name is None:
|
||||
d_vector = self.speaker_manager.get_random_d_vector()
|
||||
else:
|
||||
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
|
||||
elif config.use_speaker_embedding:
|
||||
if speaker_name is None:
|
||||
speaker_id = self.speaker_manager.get_random_speaker_id()
|
||||
else:
|
||||
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
|
||||
|
||||
# get language id
|
||||
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
|
||||
language_id = self.language_manager.language_id_mapping[language_name]
|
||||
|
||||
return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue