diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 8f58121a..84ca6111 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -22,6 +22,7 @@ class Synthesizer(object): self, tts_checkpoint: str, tts_config_path: str, + tts_speakers_file: str = "", vocoder_checkpoint: str = "", vocoder_config: str = "", use_cuda: bool = False, @@ -44,6 +45,7 @@ class Synthesizer(object): """ self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path + self.tts_speakers_file = tts_speakers_file self.vocoder_checkpoint = vocoder_checkpoint self.vocoder_config = vocoder_config self.use_cuda = use_cuda @@ -67,9 +69,9 @@ class Synthesizer(object): return pysbd.Segmenter(language=lang, clean=True) - def _load_speakers(self) -> None: + def _load_speakers(self, speaker_file: str) -> None: print("Loading speakers ...") - self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file) + self.tts_speakers = load_speaker_mapping(speaker_file) self.num_speakers = len(self.tts_speakers) self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][ "embedding" @@ -79,12 +81,12 @@ class Synthesizer(object): speaker_embedding = None - if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key: - raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'") + if not speaker_json_key: + raise ValueError(" [!] While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'") if speaker_json_key != "": assert self.tts_speakers - assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_idx}'" + assert speaker_json_key in self.tts_speakers, f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'" speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"] return speaker_embedding @@ -109,7 +111,7 @@ class Synthesizer(object): self.input_size = len(symbols) if self.tts_config.use_speaker_embedding is True: - self._load_speakers() + self._load_speakers(self.tts_config.get('external_speaker_embedding_file', self.tts_speakers_file)) self.tts_model = setup_model( self.input_size,