mirror of https://github.com/coqui-ai/TTS.git
Update `base_tts.py`
Enable calling `make_symbols()` from the model if defined. Compatibility changes for end2end `tts` models in batch formatting. Changes in multi-speaker initialization. Modify `test_run()` to work with the dict output of `synthesis`.
This commit is contained in:
parent
bf562cf437
commit
01324c8e70
|
@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
|
|||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
"""Initialize a speaker embedding layer if needed and define expected embedding channel size for defining
|
||||
`in_channels` size of the connected layers.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
This implementation yields 3 possible outcomes:
|
||||
|
||||
1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
|
||||
2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
|
||||
3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
|
||||
`config.d_vector_dim` or 512.
|
||||
|
||||
You can override this function for new models.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
|
@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
|
|||
"""
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
|
||||
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
|
||||
if data is not None or self.speaker_manager.speaker_ids:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
else:
|
||||
self.num_speakers = (
|
||||
config.num_speakers
|
||||
if "num_speakers" in config and config.num_speakers != 0
|
||||
else self.speaker_manager.num_speakers
|
||||
)
|
||||
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
|
@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
|
|||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||
linear_input = batch[3]
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
|
@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
|
|||
d_vectors = batch[8]
|
||||
speaker_ids = batch[9]
|
||||
attn_mask = batch[10]
|
||||
waveform = batch[11]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
|
@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
|
|||
"max_text_length": float(max_text_length),
|
||||
"max_spec_length": float(max_spec_length),
|
||||
"item_idx": item_idx,
|
||||
"waveform": waveform,
|
||||
}
|
||||
|
||||
def get_data_loader(
|
||||
|
@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
|
|||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
||||
# setup custom symbols if needed
|
||||
custom_symbols = None
|
||||
if hasattr(self, "make_symbols"):
|
||||
custom_symbols = self.make_symbols(self.config)
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
compute_linear_spec=config.model.lower() == "tacotron",
|
||||
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
|
||||
meta_data=data_items,
|
||||
ap=ap,
|
||||
characters=config.characters,
|
||||
custom_symbols=custom_symbols,
|
||||
add_blank=config["add_blank"],
|
||||
return_wav=config.return_wav if "return_wav" in config else False,
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
min_seq_len=config.min_seq_len,
|
||||
max_seq_len=config.max_seq_len,
|
||||
|
@ -220,7 +248,7 @@ class BaseTTS(BaseModel):
|
|||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
|
@ -232,9 +260,12 @@ class BaseTTS(BaseModel):
|
|||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
)
|
||||
return test_figures, test_audios
|
||||
|
|
Loading…
Reference in New Issue