mirror of https://github.com/coqui-ai/TTS.git
Update BaseTTS for multi-speaker training
This commit is contained in:
parent
e62d3c5cf7
commit
2b7d159383
|
@ -80,14 +80,12 @@ class BaseTTS(BaseModel):
|
||||||
Args:
|
Args:
|
||||||
config (Coqpit): Model configuration.
|
config (Coqpit): Model configuration.
|
||||||
"""
|
"""
|
||||||
# init speaker manager
|
|
||||||
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
|
|
||||||
raise ValueError(
|
|
||||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
|
||||||
)
|
|
||||||
# set number of speakers
|
# set number of speakers
|
||||||
if self.speaker_manager is not None:
|
if self.speaker_manager is not None:
|
||||||
self.num_speakers = self.speaker_manager.num_speakers
|
self.num_speakers = self.speaker_manager.num_speakers
|
||||||
|
elif hasattr(config, "num_speakers"):
|
||||||
|
self.num_speakers = config.num_speakers
|
||||||
|
|
||||||
# set ultimate speaker embedding size
|
# set ultimate speaker embedding size
|
||||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||||
self.embedded_speaker_dim = (
|
self.embedded_speaker_dim = (
|
||||||
|
@ -189,13 +187,9 @@ class BaseTTS(BaseModel):
|
||||||
ap = assets["audio_processor"]
|
ap = assets["audio_processor"]
|
||||||
|
|
||||||
# setup multi-speaker attributes
|
# setup multi-speaker attributes
|
||||||
if hasattr(self, "speaker_manager"):
|
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
|
||||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||||
d_vector_mapping = (
|
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
|
||||||
self.speaker_manager.d_vectors
|
|
||||||
if config.use_speaker_embedding and config.use_d_vector_file
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
speaker_id_mapping = None
|
speaker_id_mapping = None
|
||||||
d_vector_mapping = None
|
d_vector_mapping = None
|
||||||
|
@ -228,9 +222,7 @@ class BaseTTS(BaseModel):
|
||||||
use_noise_augment=not is_eval,
|
use_noise_augment=not is_eval,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
speaker_id_mapping=speaker_id_mapping,
|
speaker_id_mapping=speaker_id_mapping,
|
||||||
d_vector_mapping=d_vector_mapping
|
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
|
||||||
if config.use_speaker_embedding and config.use_d_vector_file
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# pre-compute phonemes
|
# pre-compute phonemes
|
||||||
|
@ -292,13 +284,17 @@ class BaseTTS(BaseModel):
|
||||||
def _get_test_aux_input(
|
def _get_test_aux_input(
|
||||||
self,
|
self,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
|
|
||||||
|
d_vector = None
|
||||||
|
if self.config.use_d_vector_file:
|
||||||
|
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
|
||||||
|
d_vector = (random.sample(sorted(d_vector), 1),)
|
||||||
|
|
||||||
aux_inputs = {
|
aux_inputs = {
|
||||||
"speaker_id": None
|
"speaker_id": None
|
||||||
if not self.config.use_speaker_embedding
|
if not self.config.use_speaker_embedding
|
||||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||||
"d_vector": None
|
"d_vector": d_vector,
|
||||||
if not self.config.use_d_vector_file
|
|
||||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
|
||||||
"style_wav": None, # TODO: handle GST style input
|
"style_wav": None, # TODO: handle GST style input
|
||||||
}
|
}
|
||||||
return aux_inputs
|
return aux_inputs
|
||||||
|
|
Loading…
Reference in New Issue