mirror of https://github.com/coqui-ai/TTS.git
Update BaseTTS for multi-speaker training
This commit is contained in:
parent
e62d3c5cf7
commit
2b7d159383
|
@ -80,14 +80,12 @@ class BaseTTS(BaseModel):
|
|||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
# init speaker manager
|
||||
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
|
||||
raise ValueError(
|
||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
||||
)
|
||||
# set number of speakers
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
elif hasattr(config, "num_speakers"):
|
||||
self.num_speakers = config.num_speakers
|
||||
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
|
@ -189,13 +187,9 @@ class BaseTTS(BaseModel):
|
|||
ap = assets["audio_processor"]
|
||||
|
||||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager"):
|
||||
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = (
|
||||
self.speaker_manager.d_vectors
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None
|
||||
)
|
||||
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
@ -228,9 +222,7 @@ class BaseTTS(BaseModel):
|
|||
use_noise_augment=not is_eval,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None,
|
||||
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||
)
|
||||
|
||||
# pre-compute phonemes
|
||||
|
@ -292,13 +284,17 @@ class BaseTTS(BaseModel):
|
|||
def _get_test_aux_input(
|
||||
self,
|
||||
) -> Dict:
|
||||
|
||||
d_vector = None
|
||||
if self.config.use_d_vector_file:
|
||||
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
|
||||
d_vector = (random.sample(sorted(d_vector), 1),)
|
||||
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||
"d_vector": None
|
||||
if not self.config.use_d_vector_file
|
||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
||||
"d_vector": d_vector,
|
||||
"style_wav": None, # TODO: handle GST style input
|
||||
}
|
||||
return aux_inputs
|
||||
|
|
Loading…
Reference in New Issue