diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 9d2fceeb..a17e1b2b 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -3,11 +3,13 @@ import torch from coqpit import Coqpit from torch import nn +from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -15,11 +17,17 @@ class Tacotron(BaseTacotron): """Tacotron as in https://arxiv.org/abs/1703.10135 It's an autoregressive encoder-attention-decoder-postnet architecture. Check `TacotronConfig` for the arguments. + + Args: + config (TacotronConfig): Configuration for the Tacotron model. + speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is + a multi-speaker model. Defaults to None. """ - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None): super().__init__(config) + self.speaker_manager = speaker_manager chars, self.config, _ = self.get_characters(config) config.num_chars = self.num_chars = len(chars) @@ -240,21 +248,22 @@ class Tacotron(BaseTacotron): outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) # compute loss - loss_dict = criterion( - outputs["model_outputs"], - outputs["decoder_outputs"], - mel_input, - linear_input, - outputs["stop_tokens"], - stop_targets, - stop_target_lengths, - mel_lengths, - outputs["decoder_outputs_backward"], - outputs["alignments"], - alignment_lengths, - outputs["alignments_backward"], - text_lengths, - ) + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + linear_input.float(), + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + mel_lengths, + outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + outputs["alignments_backward"].float(), + text_lengths, + ) # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(outputs["alignments"]) @@ -263,17 +272,23 @@ class Tacotron(BaseTacotron): def _create_logs(self, batch, outputs, ap): postnet_outputs = outputs["model_outputs"] + decoder_outputs = outputs["decoder_outputs"] alignments = outputs["alignments"] alignments_backward = outputs["alignments_backward"] mel_input = batch["mel_input"] + linear_input = batch["linear_input"] - pred_spec = postnet_outputs[0].data.cpu().numpy() - gt_spec = mel_input[0].data.cpu().numpy() + pred_linear_spec = postnet_outputs[0].data.cpu().numpy() + pred_mel_spec = decoder_outputs[0].data.cpu().numpy() + gt_linear_spec = linear_input[0].data.cpu().numpy() + gt_mel_spec = mel_input[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy() figures = { - "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False), + "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False), + "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False), + "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False), "alignment": plot_alignment(align_img, output_fig=False), } @@ -281,7 +296,7 @@ class Tacotron(BaseTacotron): figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) # Sample audio - audio = ap.inv_spectrogram(pred_spec.T) + audio = ap.inv_spectrogram(pred_linear_spec.T) return figures, {"audio": audio} def train_log( diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 6b695e2d..e2ae8532 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -3,22 +3,45 @@ import torch from coqpit import Coqpit from torch import nn +from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.visual import plot_alignment, plot_spectrogram class Tacotron2(BaseTacotron): - """Tacotron2 as in https://arxiv.org/abs/1712.05884 - Check `TacotronConfig` for the arguments. + """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`. + + Paper:: + https://arxiv.org/abs/1712.05884 + + Paper abstract:: + This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text. + The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character + embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize + timedomain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable + to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation + studies of key components of our system and evaluate the impact of using mel spectrograms as the input to + WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic + intermediate representation enables significant simplification of the WaveNet architecture. + + Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments. + + Args: + config (TacotronConfig): + Configuration for the Tacotron2 model. + speaker_manager (SpeakerManager): + Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. """ - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None): super().__init__(config) + self.speaker_manager = speaker_manager chars, self.config, _ = self.get_characters(config) config.num_chars = len(chars) self.decoder_output_dim = config.out_channels @@ -28,9 +51,7 @@ class Tacotron2(BaseTacotron): for key in config: setattr(self, key, config[key]) - # set speaker embedding channel size for determining `in_channels` for the connected layers. - # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based - # on the number of speakers infered from the dataset. + # init multi-speaker layers if self.use_speaker_embedding or self.use_d_vector_file: self.init_multispeaker(config) self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim @@ -100,6 +121,7 @@ class Tacotron2(BaseTacotron): @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): + """Final reshape of the model output tensors.""" mel_outputs = mel_outputs.transpose(1, 2) mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments @@ -107,7 +129,8 @@ class Tacotron2(BaseTacotron): def forward( # pylint: disable=dangerous-default-value self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} ): - """ + """Forward pass for training with Teacher Forcing. + Shapes: text: [B, T_in] text_lengths: [B] @@ -174,6 +197,12 @@ class Tacotron2(BaseTacotron): @torch.no_grad() def inference(self, text, aux_input=None): + """Forward pass for inference with no Teacher-Forcing. + + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + """ aux_input = self._format_aux_input(aux_input) embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) @@ -208,7 +237,7 @@ class Tacotron2(BaseTacotron): return outputs def train_step(self, batch, criterion): - """Perform a single training step by fetching the right set if samples from the batch. + """A single training step. Forward pass and loss computation. Args: batch ([type]): [description] @@ -218,7 +247,6 @@ class Tacotron2(BaseTacotron): text_lengths = batch["text_lengths"] mel_input = batch["mel_input"] mel_lengths = batch["mel_lengths"] - linear_input = batch["linear_input"] stop_targets = batch["stop_targets"] stop_target_lengths = batch["stop_target_lengths"] speaker_ids = batch["speaker_ids"] @@ -245,21 +273,22 @@ class Tacotron2(BaseTacotron): outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) # compute loss - loss_dict = criterion( - outputs["model_outputs"], - outputs["decoder_outputs"], - mel_input, - linear_input, - outputs["stop_tokens"], - stop_targets, - stop_target_lengths, - mel_lengths, - outputs["decoder_outputs_backward"], - outputs["alignments"], - alignment_lengths, - outputs["alignments_backward"], - text_lengths, - ) + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + None, + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(outputs["alignments"]) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index c738f50f..7561780f 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -217,12 +217,13 @@ class Vits(BaseTTS): # pylint: disable=dangerous-default-value - def __init__(self, config: Coqpit): + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None): super().__init__(config) self.END2END = True + self.speaker_manager = speaker_manager if config.__class__.__name__ == "VitsConfig": # loading from VitsConfig if "num_chars" not in config: @@ -314,7 +315,7 @@ class Vits(BaseTTS): if args.init_discriminator: self.disc = VitsDiscriminator(use_spectral_norm=args.use_spectral_norm_disriminator) - def init_multispeaker(self, config: Coqpit, data: List = None): + def init_multispeaker(self, config: Coqpit): """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer or with external `d_vectors` computed from a speaker encoder model. @@ -351,18 +352,6 @@ class Vits(BaseTTS): self.speaker_manager = SpeakerManager(d_vectors_file_path=config.d_vector_file) self.embedded_speaker_dim = config.d_vector_dim - def on_init_start(self, trainer): - """Save the speaker.json at the beginning of the training. And update the config.json with the - speakers.json file path.""" - if self.speaker_manager is not None: - output_path = os.path.join(trainer.output_path, "speakers.json") - self.speaker_manager.save_speaker_ids_to_file(output_path) - trainer.config.speakers_file = output_path - trainer.config.model_args.speakers_file = output_path - trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) - print(f" > `speakers.json` is saved to {output_path}.") - print(" > `speakers_file` is updated in the config.json.") - @staticmethod def _set_cond_input(aux_input: Dict): """Set the speaker conditioning input based on the multi-speaker mode.""" diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index f5fb1d7f..19a16e5e 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -108,6 +108,8 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method class AudioProcessor(object): """Audio Processor for TTS used by all the data pipelines. + TODO: Make this a dataclass to replace `BaseAudioConfig`. + Note: All the class arguments are set to default values to enable a flexible initialization of the class with the model config. They are not meaningful for all the arguments.