diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 236e78a9..9ecb5be9 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -87,52 +87,15 @@ class Synthesizer(object): """ return pysbd.Segmenter(language=lang, clean=True) - def _load_speakers(self, speaker_file: str) -> None: - """Load the SpeakerManager to organize multi-speaker TTS. It loads the speakers meta-data and the speaker - encoder if it is defined. - - Args: - speaker_file (str): path to the speakers meta-data file. - """ - print("Loading speakers ...") - self.speaker_manager = SpeakerManager( - encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config - ) - self.speaker_manager.load_d_vectors_file(self.tts_config.get("d_vector_file", speaker_file)) - self.num_speakers = self.speaker_manager.num_speakers - self.d_vector_dim = self.speaker_manager.d_vector_dim - - def _set_tts_speaker_file(self): - """Set the TTS speaker file used by a multi-speaker model.""" - # setup if multi-speaker settings are in the global model config - if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True: - if self.tts_config.use_d_vector_file: - self.tts_speakers_file = ( - self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"] - ) - self.tts_config["d_vector_file"] = self.tts_speakers_file - else: - self.tts_speakers_file = ( - self.tts_speakers_file if self.tts_speakers_file else self.tts_config["speakers_file"] - ) - - # setup if multi-speaker settings are in the model args config - if ( - self.tts_speakers_file is None - and hasattr(self.tts_config, "model_args") - and hasattr(self.tts_config.model_args, "use_speaker_embedding") - and self.tts_config.model_args.use_speaker_embedding - ): - _args = self.tts_config.model_args - if _args.use_d_vector_file: - self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["d_vector_file"] - _args["d_vector_file"] = self.tts_speakers_file - else: - self.tts_speakers_file = self.tts_speakers_file if self.tts_speakers_file else _args["speakers_file"] - def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: """Load the TTS model. + 1. Load the model config. + 2. Init the AudioProcessor. + 3. Init the model from the config. + 4. Move the model to the GPU if CUDA is enabled. + 5. Init the speaker manager for the model. + Args: tts_checkpoint (str): path to the model checkpoint. tts_config_path (str): path to the model config file. @@ -148,11 +111,34 @@ class Synthesizer(object): self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) if use_cuda: self.tts_model.cuda() - self._set_tts_speaker_file() + speaker_manager = self._init_speaker_manager() + self.tts_model.speaker_manager = speaker_manager + + def _init_speaker_manager(self): + """Initialize the SpeakerManager""" + # setup if multi-speaker settings are in the global model config + speaker_manager = None + if hasattr(self.tts_config, "use_speaker_embedding") and self.tts_config.use_speaker_embedding is True: + if self.tts_speakers_file: + speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_speakers_file) + if self.tts_config.get("speakers_file", None): + speaker_manager = SpeakerManager(speaker_id_file_path=self.tts_config.speakers_file) + + if hasattr(self.tts_config, "use_d_vector_file") and self.tts_config.use_speaker_embedding is True: + if self.tts_speakers_file: + speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_speakers_file) + if self.tts_config.get("d_vector_file", None): + speaker_manager = SpeakerManager(d_vectors_file_path=self.tts_config.d_vector_file) + return speaker_manager def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: """Load the vocoder model. + 1. Load the vocoder config. + 2. Init the AudioProcessor for the vocoder. + 3. Init the vocoder model from the config. + 4. Move the model to the GPU if CUDA is enabled. + Args: model_file (str): path to the model checkpoint. model_config (str): path to the model config file. @@ -207,7 +193,7 @@ class Synthesizer(object): # handle multi-speaker speaker_embedding = None speaker_id = None - if self.tts_speakers_file: + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "speaker_ids"): if speaker_idx and isinstance(speaker_idx, str): if self.tts_config.use_d_vector_file: # get the speaker embedding from the saved d_vectors. @@ -226,7 +212,7 @@ class Synthesizer(object): else: if speaker_idx: raise ValueError( - f" [!] Missing speaker.json file path for selecting speaker {speaker_idx}." + f" [!] Missing speakers.json file path for selecting speaker {speaker_idx}." "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " )