Update BaseTTS for multi-speaker training

This commit is contained in:
Eren Gölge 2021-10-21 16:12:22 +00:00
parent e62d3c5cf7
commit 2b7d159383
1 changed files with 13 additions and 17 deletions

View File

@ -80,14 +80,12 @@ class BaseTTS(BaseModel):
Args:
config (Coqpit): Model configuration.
"""
# init speaker manager
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
elif hasattr(config, "num_speakers"):
self.num_speakers = config.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
@ -189,13 +187,9 @@ class BaseTTS(BaseModel):
ap = assets["audio_processor"]
# setup multi-speaker attributes
if hasattr(self, "speaker_manager"):
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = (
self.speaker_manager.d_vectors
if config.use_speaker_embedding and config.use_d_vector_file
else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None
@ -228,9 +222,7 @@ class BaseTTS(BaseModel):
use_noise_augment=not is_eval,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping
if config.use_speaker_embedding and config.use_d_vector_file
else None,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
)
# pre-compute phonemes
@ -292,13 +284,17 @@ class BaseTTS(BaseModel):
def _get_test_aux_input(
self,
) -> Dict:
d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
d_vector = (random.sample(sorted(d_vector), 1),)
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
"d_vector": d_vector,
"style_wav": None, # TODO: handle GST style input
}
return aux_inputs