Update BaseTTS for multi-speaker training

This commit is contained in:
Eren Gölge 2021-10-21 16:12:22 +00:00
parent e62d3c5cf7
commit 2b7d159383
1 changed file with 13 additions and 17 deletions

View File

@ -80,14 +80,12 @@ class BaseTTS(BaseModel):
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
""" """
# init speaker manager
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers # set number of speakers
if self.speaker_manager is not None: if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers self.num_speakers = self.speaker_manager.num_speakers
elif hasattr(config, "num_speakers"):
self.num_speakers = config.num_speakers
# set ultimate speaker embedding size # set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file: if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = ( self.embedded_speaker_dim = (
@ -189,13 +187,9 @@ class BaseTTS(BaseModel):
ap = assets["audio_processor"] ap = assets["audio_processor"]
# setup multi-speaker attributes # setup multi-speaker attributes
if hasattr(self, "speaker_manager"): if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = ( d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
self.speaker_manager.d_vectors
if config.use_speaker_embedding and config.use_d_vector_file
else None
)
else: else:
speaker_id_mapping = None speaker_id_mapping = None
d_vector_mapping = None d_vector_mapping = None
@ -228,9 +222,7 @@ class BaseTTS(BaseModel):
use_noise_augment=not is_eval, use_noise_augment=not is_eval,
verbose=verbose, verbose=verbose,
speaker_id_mapping=speaker_id_mapping, speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
if config.use_speaker_embedding and config.use_d_vector_file
else None,
) )
# pre-compute phonemes # pre-compute phonemes
@ -292,13 +284,17 @@ class BaseTTS(BaseModel):
def _get_test_aux_input( def _get_test_aux_input(
self, self,
) -> Dict: ) -> Dict:
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
d_vector = (random.sample(sorted(d_vector), 1),)
aux_inputs = { aux_inputs = {
"speaker_id": None "speaker_id": None
if not self.config.use_speaker_embedding if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1), else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None "d_vector": d_vector,
if not self.config.use_d_vector_file
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None, # TODO: handle GST style input "style_wav": None, # TODO: handle GST style input
} }
return aux_inputs return aux_inputs