mirror of https://github.com/coqui-ai/TTS.git
Fix synthesizer reading `use_language_embedding`
This commit is contained in:
parent
7a987db62b
commit
35a781fb90
|
@ -151,7 +151,10 @@ class Synthesizer(object):
|
||||||
"""Initialize the LanguageManager"""
|
"""Initialize the LanguageManager"""
|
||||||
# setup if multi-lingual settings are in the global model config
|
# setup if multi-lingual settings are in the global model config
|
||||||
language_manager = None
|
language_manager = None
|
||||||
if hasattr(self.tts_config, "use_language_embedding") and self.tts_config.use_language_embedding is True:
|
if (
|
||||||
|
hasattr(self.tts_config.model_args, "use_language_embedding")
|
||||||
|
and self.tts_config.model_args.use_language_embedding is True
|
||||||
|
):
|
||||||
if self.tts_languages_file:
|
if self.tts_languages_file:
|
||||||
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
|
language_manager = LanguageManager(language_ids_file_path=self.tts_languages_file)
|
||||||
elif self.tts_config.get("language_ids_file", None):
|
elif self.tts_config.get("language_ids_file", None):
|
||||||
|
@ -200,14 +203,14 @@ class Synthesizer(object):
|
||||||
self.ap.save_wav(wav, path, self.output_sample_rate)
|
self.ap.save_wav(wav, path, self.output_sample_rate)
|
||||||
|
|
||||||
def tts(
|
def tts(
|
||||||
self, text: str, speaker_idx: str = "", language_idx: str = "", speaker_wav=None, style_wav=None
|
self, text: str, speaker_name: str = "", language_name: str = "", speaker_wav=None, style_wav=None
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""🐸 TTS magic. Run all the models and generate speech.
|
"""🐸 TTS magic. Run all the models and generate speech.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str): input text.
|
text (str): input text.
|
||||||
speaker_idx (str, optional): spekaer id for multi-speaker models. Defaults to "".
|
speaker_name (str, optional): spekaer id for multi-speaker models. Defaults to "".
|
||||||
language_idx (str, optional): language id for multi-language models. Defaults to "".
|
language_name (str, optional): language id for multi-language models. Defaults to "".
|
||||||
speaker_wav ():
|
speaker_wav ():
|
||||||
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
style_wav ([type], optional): style waveform for GST. Defaults to None.
|
||||||
|
|
||||||
|
@ -224,26 +227,26 @@ class Synthesizer(object):
|
||||||
speaker_embedding = None
|
speaker_embedding = None
|
||||||
speaker_id = None
|
speaker_id = None
|
||||||
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
|
if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"):
|
||||||
if speaker_idx and isinstance(speaker_idx, str):
|
if speaker_name and isinstance(speaker_name, str):
|
||||||
if self.tts_config.use_d_vector_file:
|
if self.tts_config.use_d_vector_file:
|
||||||
# get the speaker embedding from the saved d_vectors.
|
# get the speaker embedding from the saved d_vectors.
|
||||||
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
|
speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_name)[0]
|
||||||
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
|
speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim]
|
||||||
else:
|
else:
|
||||||
# get speaker idx from the speaker name
|
# get speaker idx from the speaker name
|
||||||
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_idx]
|
speaker_id = self.tts_model.speaker_manager.speaker_ids[speaker_name]
|
||||||
|
|
||||||
elif not speaker_idx and not speaker_wav:
|
elif not speaker_name and not speaker_wav:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
" [!] Look like you use a multi-speaker model. "
|
" [!] Look like you use a multi-speaker model. "
|
||||||
"You need to define either a `speaker_idx` or a `style_wav` to use a multi-speaker model."
|
"You need to define either a `speaker_name` or a `style_wav` to use a multi-speaker model."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
speaker_embedding = None
|
speaker_embedding = None
|
||||||
else:
|
else:
|
||||||
if speaker_idx:
|
if speaker_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}."
|
f" [!] Missing speakers.json file path for selecting speaker {speaker_name}."
|
||||||
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
"Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -252,18 +255,18 @@ class Synthesizer(object):
|
||||||
if self.tts_languages_file or (
|
if self.tts_languages_file or (
|
||||||
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
|
hasattr(self.tts_model, "language_manager") and self.tts_model.language_manager is not None
|
||||||
):
|
):
|
||||||
if language_idx and isinstance(language_idx, str):
|
if language_name and isinstance(language_name, str):
|
||||||
language_id = self.tts_model.language_manager.language_id_mapping[language_idx]
|
language_id = self.tts_model.language_manager.language_id_mapping[language_name]
|
||||||
|
|
||||||
elif not language_idx:
|
elif not language_name:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
" [!] Look like you use a multi-lingual model. "
|
" [!] Look like you use a multi-lingual model. "
|
||||||
"You need to define either a `language_idx` or a `style_wav` to use a multi-lingual model."
|
"You need to define either a `language_name` or a `style_wav` to use a multi-lingual model."
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f" [!] Missing language_ids.json file path for selecting language {language_idx}."
|
f" [!] Missing language_ids.json file path for selecting language {language_name}."
|
||||||
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
|
"Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. "
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -283,7 +286,7 @@ class Synthesizer(object):
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
speaker_id=speaker_id,
|
speaker_id=speaker_id,
|
||||||
language_id=language_id,
|
language_id=language_id,
|
||||||
language_name=language_idx,
|
language_name=language_name,
|
||||||
style_wav=style_wav,
|
style_wav=style_wav,
|
||||||
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
|
||||||
use_griffin_lim=use_gl,
|
use_griffin_lim=use_gl,
|
||||||
|
|
Loading…
Reference in New Issue