Update `base_tts.py`

Enable calling `make_symbols()` from the model if defined.
Compatibility changes for end2end `tts` models in batch formatting.
Changes in multi-speaker initialization.
Modify `test_run()` to work with the dict output of `synthesis`.
Eren Gölge 2021-08-07 21:35:04 +00:00
parent bf562cf437
commit 01324c8e70
1 changed file with 45 additions and 14 deletions


@@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
        return get_speaker_manager(config, restore_path, data, out_path)

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.
        """Initialize a speaker embedding layer if needed and define the expected embedding channel size used to set
        the `in_channels` of the connected layers.

        If you need a different behaviour, override this function for your model.

        This implementation yields 3 possible outcomes:

        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
        3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with a channel size of
           `config.d_vector_dim` or 512.

        You can override this function for new models.

        Args:
            config (Coqpit): Model configuration.
@@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
        """
        # init speaker manager
        self.speaker_manager = get_speaker_manager(config, data=data)
        self.num_speakers = self.speaker_manager.num_speakers
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
        if data is not None or self.speaker_manager.speaker_ids:
            self.num_speakers = self.speaker_manager.num_speakers
        else:
            self.num_speakers = (
                config.num_speakers
                if "num_speakers" in config and config.num_speakers != 0
                else self.speaker_manager.num_speakers
            )
        # set ultimate speaker embedding size
        if config.use_speaker_embedding or config.use_d_vector_file:
            self.embedded_speaker_dim = (
                config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
            )
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)
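
For illustration, a minimal, self-contained sketch of the channel-size resolution implemented above. `resolve_embedded_speaker_dim` is a hypothetical helper, not part of the library; it only mirrors the branching in `init_multispeaker`:

    from typing import Optional

    def resolve_embedded_speaker_dim(
        use_speaker_embedding: bool,
        use_d_vector_file: bool,
        d_vector_dim: Optional[int],
    ) -> Optional[int]:
        # Outcome 1: no speaker conditioning, so there is no embedding size to set.
        if not (use_speaker_embedding or use_d_vector_file):
            return None
        # Outcomes 2 and 3: use the configured size, falling back to 512.
        return d_vector_dim if d_vector_dim is not None else 512

    assert resolve_embedded_speaker_dim(False, False, None) is None
    assert resolve_embedded_speaker_dim(False, True, 256) == 256
    assert resolve_embedded_speaker_dim(True, False, None) == 512
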
@@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
        text_input = batch[0]
        text_lengths = batch[1]
        speaker_names = batch[2]
        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
        linear_input = batch[3]
        mel_input = batch[4]
        mel_lengths = batch[5]
        stop_targets = batch[6]
@@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
        d_vectors = batch[8]
        speaker_ids = batch[9]
        attn_mask = batch[10]
        waveform = batch[11]

        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())
@@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
        }

    def get_data_loader(
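
To show how the new "waveform" entry might be consumed, a hedged sketch of an end-to-end model's `train_step`; apart from "waveform", the key names are assumptions modeled on the dict built above, and the entry only carries audio when the dataset is created with `return_wav` enabled (see the next hunk):

    def train_step(self, batch: dict, criterion):
        text_input = batch["text_input"]  # assumed key, by analogy with format_batch() above
        mel_input = batch["mel_input"]    # assumed key
        waveform = batch["waveform"]      # raw audio entry added by this commit
        ...
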
@@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup custom symbols if needed
            custom_symbols = None
            if hasattr(self, "make_symbols"):
                custom_symbols = self.make_symbols(self.config)

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                text_cleaner=config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron",
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                meta_data=data_items,
                ap=ap,
                characters=config.characters,
                custom_symbols=custom_symbols,
                add_blank=config["add_blank"],
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
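
As an illustration of the new hook, a hypothetical model that provides its own symbol set; the class name and symbols are made up, and only the `hasattr(self, "make_symbols")` contract above comes from the source:

    from TTS.tts.models.base_tts import BaseTTS  # import path assumed from the repo layout

    class MyEnd2EndTTS(BaseTTS):
        def make_symbols(self, config):
            # Picked up by get_data_loader() via the hasattr() check above and
            # passed to TTSDataset as `custom_symbols`. A real model would
            # derive the list from config.characters.
            return list("abcdefghijklmnopqrstuvwxyz ,.!?")
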
@@ -220,7 +248,7 @@ class BaseTTS(BaseModel):
        test_sentences = self.config.test_sentences
        aux_inputs = self.get_aux_input()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, model_outputs, _ = synthesis(
            outputs_dict = synthesis(
                self,
                sen,
                self.config,
@@ -232,9 +260,12 @@ class BaseTTS(BaseModel):
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
            )
            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
            )
            test_figures["{}-alignment".format(idx)] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
            )
        return test_figures, test_audios
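
For reference, the structure of the dict returned by `synthesis` as consumed in `test_run` above; the keys are inferred from this hunk, and the real dict may carry additional entries:

    outputs_dict = {
        "wav": None,                    # synthesized waveform (Griffin-Lim here)
        "outputs": {
            "model_outputs": None,      # predicted spectrogram, plotted above
            "alignments": None,         # attention map, plotted above
        },
    }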