diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index cd84bf08..b77c1e23 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -80,14 +80,12 @@ class BaseTTS(BaseModel):
         Args:
             config (Coqpit): Model configuration.
         """
-        # init speaker manager
-        if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
-            raise ValueError(
-                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
-            )
         # set number of speakers
         if self.speaker_manager is not None:
             self.num_speakers = self.speaker_manager.num_speakers
+        elif hasattr(config, "num_speakers"):
+            self.num_speakers = config.num_speakers
+
         # set ultimate speaker embedding size
         if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
@@ -189,13 +187,9 @@ class BaseTTS(BaseModel):
             ap = assets["audio_processor"]
 
             # setup multi-speaker attributes
-            if hasattr(self, "speaker_manager"):
+            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                 speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
-                d_vector_mapping = (
-                    self.speaker_manager.d_vectors
-                    if config.use_speaker_embedding and config.use_d_vector_file
-                    else None
-                )
+                d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
             else:
                 speaker_id_mapping = None
                 d_vector_mapping = None
@@ -228,9 +222,7 @@ class BaseTTS(BaseModel):
                 use_noise_augment=not is_eval,
                 verbose=verbose,
                 speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping
-                if config.use_speaker_embedding and config.use_d_vector_file
-                else None,
+                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
             )
 
             # pre-compute phonemes
@@ -292,13 +284,17 @@ class BaseTTS(BaseModel):
     def _get_test_aux_input(
         self,
     ) -> Dict:
+
+        d_vector = None
+        if self.config.use_d_vector_file:
+            d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
+            d_vector = (random.sample(sorted(d_vector), 1),)
+
         aux_inputs = {
             "speaker_id": None
             if not self.config.use_speaker_embedding
             else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
-            "d_vector": None
-            if not self.config.use_d_vector_file
-            else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
+            "d_vector": d_vector,
             "style_wav": None,  # TODO: handle GST style input
         }
         return aux_inputs
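
For reference, a minimal self-contained sketch of the speaker-count fallback introduced by the first hunk: instead of raising when no `SpeakerManager` is attached, `num_speakers` now falls back to `config.num_speakers` when the config defines it. `DummyConfig` and `DummySpeakerManager` below are hypothetical stand-ins for a Coqpit config and `SpeakerManager`, not part of the Coqui TTS API, and `resolve_num_speakers` is only an illustrative helper mirroring the patched logic.

```python
# Illustrative sketch only; DummyConfig and DummySpeakerManager are hypothetical
# stand-ins for a Coqpit model config and TTS's SpeakerManager.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    use_speaker_embedding: bool = True
    use_d_vector_file: bool = False
    num_speakers: int = 4


@dataclass
class DummySpeakerManager:
    num_speakers: int = 10


def resolve_num_speakers(speaker_manager, config):
    """Mirror the fallback order of the patched init_multispeaker():
    prefer the SpeakerManager count, then config.num_speakers, else 0."""
    if speaker_manager is not None:
        return speaker_manager.num_speakers
    if hasattr(config, "num_speakers"):
        return config.num_speakers
    return 0


print(resolve_num_speakers(DummySpeakerManager(), DummyConfig()))  # -> 10
print(resolve_num_speakers(None, DummyConfig()))                   # -> 4 (config fallback)
```

The same relaxation shows up in the data-loader hunks: `d_vector_mapping` is now gated only on `config.use_d_vector_file` rather than on both multi-speaker flags.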