diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 11f4b7cc..e441cc05 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -48,10 +48,17 @@ class BaseTTS(BaseModel):
         return get_speaker_manager(config, restore_path, data, out_path)
 
     def init_multispeaker(self, config: Coqpit, data: List = None):
-        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
-        or with external `d_vectors` computed from a speaker encoder model.
+        """Initialize a speaker embedding layer if needed and define the expected embedding channel size used to set
+        the `in_channels` of the connected layers.
 
-        If you need a different behaviour, override this function for your model.
+        This implementation yields 3 possible outcomes:
+
+        1. If `config.use_speaker_embedding` and `config.use_d_vector_file` are False, do nothing.
+        2. If `config.use_d_vector_file` is True, set the expected embedding channel size to `config.d_vector_dim` or 512.
+        3. If `config.use_speaker_embedding` is True, initialize a speaker embedding layer with a channel size of
+           `config.d_vector_dim` or 512.
+
+        You can override this function for new models.
 
         Args:
             config (Coqpit): Model configuration.
@@ -59,12 +66,24 @@ class BaseTTS(BaseModel):
         """
         # init speaker manager
         self.speaker_manager = get_speaker_manager(config, data=data)
-        self.num_speakers = self.speaker_manager.num_speakers
-        # init speaker embedding layer
-        if config.use_speaker_embedding and not config.use_d_vector_file:
+
+        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
+        if data is not None or self.speaker_manager.speaker_ids:
+            self.num_speakers = self.speaker_manager.num_speakers
+        else:
+            self.num_speakers = (
+                config.num_speakers
+                if "num_speakers" in config and config.num_speakers != 0
+                else self.speaker_manager.num_speakers
+            )
+
+        # set ultimate speaker embedding size
+        if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
                 config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
             )
+        # init speaker embedding layer
+        if config.use_speaker_embedding and not config.use_d_vector_file:
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
@@ -87,7 +106,7 @@ class BaseTTS(BaseModel):
         text_input = batch[0]
         text_lengths = batch[1]
         speaker_names = batch[2]
-        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
+        linear_input = batch[3]
         mel_input = batch[4]
         mel_lengths = batch[5]
         stop_targets = batch[6]
@@ -95,6 +114,7 @@ class BaseTTS(BaseModel):
         d_vectors = batch[8]
         speaker_ids = batch[9]
         attn_mask = batch[10]
+        waveform = batch[11]
 
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -140,6 +160,7 @@ class BaseTTS(BaseModel):
             "max_text_length": float(max_text_length),
             "max_spec_length": float(max_spec_length),
             "item_idx": item_idx,
+            "waveform": waveform,
         }
 
     def get_data_loader(
@@ -160,15 +181,22 @@ class BaseTTS(BaseModel):
                 speaker_id_mapping = None
                 d_vector_mapping = None
 
+            # setup custom symbols if needed
+            custom_symbols = None
+            if hasattr(self, "make_symbols"):
+                custom_symbols = self.make_symbols(self.config)
+
             # init dataloader
             dataset = TTSDataset(
                 outputs_per_step=config.r if "r" in config else 1,
                 text_cleaner=config.text_cleaner,
-                compute_linear_spec=config.model.lower() == "tacotron",
+                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 meta_data=data_items,
                 ap=ap,
                 characters=config.characters,
+                custom_symbols=custom_symbols,
                 add_blank=config["add_blank"],
+                return_wav=config.return_wav if "return_wav" in config else False,
                 batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                 min_seq_len=config.min_seq_len,
                 max_seq_len=config.max_seq_len,
@@ -220,7 +248,7 @@ class BaseTTS(BaseModel):
         test_sentences = self.config.test_sentences
         aux_inputs = self.get_aux_input()
         for idx, sen in enumerate(test_sentences):
-            wav, alignment, model_outputs, _ = synthesis(
+            outputs_dict = synthesis(
                 self,
                 sen,
                 self.config,
@@ -232,9 +260,12 @@ class BaseTTS(BaseModel):
                 enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                 use_griffin_lim=True,
                 do_trim_silence=False,
-            ).values()
-
-            test_audios["{}-audio".format(idx)] = wav
-            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, ap, output_fig=False)
-            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
+            )
+            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
+            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
+                outputs_dict["outputs"]["model_outputs"], ap, output_fig=False
+            )
+            test_figures["{}-alignment".format(idx)] = plot_alignment(
+                outputs_dict["outputs"]["alignments"], output_fig=False
+            )
         return test_figures, test_audios
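
A few usage notes on the changes above. With the new `return_wav` plumbing, `format_batch` surfaces the raw audio under the `"waveform"` key whenever the dataset was built with `return_wav=True`. A hypothetical `train_step` fragment consuming it (the method body and the None-guard are illustrative assumptions, not code from this diff):

```python
def train_step(self, batch: dict, criterion):
    # "waveform" is populated only when the dataset was created with
    # return_wav=True; assumed to be None otherwise (hypothetical usage).
    waveform = batch["waveform"]
    if waveform is not None:
        pass  # e.g. feed the ground-truth audio to a waveform-domain loss
```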
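Similarly, the dataloader now picks up model-specific symbol sets through an optional `make_symbols` hook, as the `hasattr` check in `get_data_loader` shows. A minimal sketch of a hypothetical model opting in (the class name and symbol set are made up for illustration):

```python
class MyTTS(BaseTTS):
    def make_symbols(self, config):
        # The symbols returned here are passed to TTSDataset as
        # `custom_symbols` and used instead of the default character set.
        return list("abcdefghijklmnopqrstuvwxyz !,.?")
```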
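Finally, since `synthesis` now returns a dict instead of a tuple, any caller doing tuple unpacking (as `test_run` did before this change) must switch to key access. A minimal sketch of the new consumption pattern, assuming `model` and `ap` are a loaded model and its audio processor; the `use_cuda`/`ap` keyword names are assumptions standing in for the arguments elided in the `test_run` hunk, while the dict keys come straight from the updated code:

```python
outputs_dict = synthesis(
    model,
    "Hello world.",
    model.config,
    use_cuda=False,  # assumed keyword; mirrors the elided args in test_run
    ap=ap,
    use_griffin_lim=True,
    do_trim_silence=False,
)
wav = outputs_dict["wav"]                          # synthesized waveform
mel = outputs_dict["outputs"]["model_outputs"]     # predicted spectrogram
attn = outputs_dict["outputs"]["alignments"]       # attention alignments
```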