diff --git a/TTS/tts/models/abstract_tts.py b/TTS/tts/models/abstract_tts.py
new file mode 100644
index 00000000..9132f7eb
--- /dev/null
+++ b/TTS/tts/models/abstract_tts.py
@@ -0,0 +1,134 @@
+from coqpit import Coqpit
+from abc import ABC, abstractmethod
+from typing import Dict, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+
+from TTS.utils.audio import AudioProcessor
+
+# pylint: skip-file
+
+
+class TTSModel(nn.Module, ABC):
+    """Abstract TTS class. Every new `tts` model must inherit this.
+
+    Notes on input/output tensor shapes:
+        Any input or output tensor of the model must be shaped as
+
+        - 3D tensors `batch x time x channels`
+        - 2D tensors `batch x channels`
+        - 1D tensors `batch x 1`
+    """
+
+    @abstractmethod
+    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
+        """Forward pass for the model, mainly used in training.
+
+        You can be flexible here and use a different number of arguments and argument names, since it is mostly
+        called by `train_step()` in training without being exposed outside of the class.
+
+        Args:
+            text (torch.Tensor): Input text character sequence ids.
+            aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs
+                for the model.
+
+        Returns:
+            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+        """
+        outputs_dict = {"model_outputs": None}
+        ...
+        return outputs_dict
+
+    @abstractmethod
+    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
+        """Forward pass for inference.
+
+        After the model is trained, this is the only function that connects the model to the outside world.
+
+        This function must only take a `text` input and a dictionary that has all the other model-specific inputs.
+        We don't use `**kwargs` since it is problematic with the TorchScript API.
+
+        Args:
+            text (torch.Tensor): Input text character sequence ids.
+            aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.
+
+        Returns:
+            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
+        """
+        outputs_dict = {"model_outputs": None}
+        ...
+        return outputs_dict
+
+    @abstractmethod
+    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        """Perform a single training step. Run the model forward pass and compute losses.
+
+        Args:
+            batch (Dict): Input tensors.
+            criterion (nn.Module): Loss layer designed for the model.
+
+        Returns:
+            Tuple[Dict, Dict]: Model outputs and computed losses.
+        """
+        outputs_dict = {}
+        loss_dict = {}  # this is returned by the criterion
+        ...
+        return outputs_dict, loss_dict
+
+    @abstractmethod
+    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+        """Create visualizations and waveform examples for training.
+
+        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
+        be projected onto Tensorboard.
+
+        Args:
+            ap (AudioProcessor): audio processor used at training.
+            batch (Dict): Model inputs used at the previous training step.
+            outputs (Dict): Model outputs generated at the previous training step.
+
+        Returns:
+            Tuple[Dict, np.ndarray]: training plots and output waveform.
+        """
+        figures_dict = {}
+        output_wav = np.array([])
+        ...
+        return figures_dict, output_wav
+
+    @abstractmethod
+    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
+        call `train_step()` with no changes.
+
+        Args:
+            batch (Dict): Input tensors.
+            criterion (nn.Module): Loss layer designed for the model.
+
+        Returns:
+            Tuple[Dict, Dict]: Model outputs and computed losses.
+        """
+        outputs_dict = {}
+        loss_dict = {}  # this is returned by the criterion
+        ...
+        return outputs_dict, loss_dict
+
+    @abstractmethod
+    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
+        """The same as `train_log()`."""
+        figures_dict = {}
+        output_wav = np.array([])
+        ...
+        return figures_dict, output_wav
+
+    @abstractmethod
+    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
+        """Load a checkpoint and get ready for training or inference.
+
+        Args:
+            config (Coqpit): Model configuration.
+            checkpoint_path (str): Path to the model checkpoint file.
+            eval (bool, optional): If true, init model for inference else for training. Defaults to False.
+        """
+        ...
diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py
index 6c268a43..75fb50de 100644
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -7,13 +7,14 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 
 
-class AlignTTS(nn.Module):
+class AlignTTS(TTSModel):
     """AlignTTS with modified duration predictor.
 
     https://arxiv.org/pdf/2003.01950.pdf
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index e61b80c2..a30eadb4 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -7,13 +7,14 @@ from torch.nn import functional as F
 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 
 
-class GlowTTS(nn.Module):
+class GlowTTS(TTSModel):
     """Glow TTS models from https://arxiv.org/abs/2005.11129
 
     Args:
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index d4a90a2e..44a47722 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -6,13 +6,14 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 
 
-class SpeedySpeech(nn.Module):
+class SpeedySpeech(TTSModel):
     """Speedy Speech model
 
     https://arxiv.org/abs/2008.03802
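
For anyone reviewing the new interface, a toy subclass makes the contract concrete. The sketch below is illustrative only and not part of the patch: `DummyTTS`, the batch keys `text_input`/`mel_input`, and the `"model"` checkpoint key are assumptions made for the example, not anything this diff defines.

```python
# Hypothetical minimal subclass, only to illustrate the TTSModel contract.
# Assumes this patch is applied; DummyTTS, the batch keys, and the "model"
# checkpoint key are illustrative assumptions, not part of the diff.
from typing import Dict, Tuple

import numpy as np
import torch
from torch import nn

from TTS.tts.models.abstract_tts import TTSModel
from TTS.utils.audio import AudioProcessor


class DummyTTS(TTSModel):
    def __init__(self, num_chars: int = 100, out_channels: int = 80):
        super().__init__()
        self.emb = nn.Embedding(num_chars, out_channels)

    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
        # Outputs are 3D `batch x time x channels`, keyed `model_outputs`.
        return {"model_outputs": self.emb(text)}

    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
        return {"model_outputs": self.emb(text)}

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        outputs = self.forward(batch["text_input"])
        loss_dict = {"loss": criterion(outputs["model_outputs"], batch["mel_input"])}
        return outputs, loss_dict

    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return {}, np.zeros(1)  # no figures, silent placeholder waveform

    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)  # eval mirrors training here

    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return self.train_log(ap, batch, outputs)

    def load_checkpoint(self, config, checkpoint_path: str, eval: bool = False) -> None:
        state = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state["model"])  # assumes a {"model": state_dict} layout
        if eval:
            self.eval()


# e.g. a single training step with random data:
model = DummyTTS()
outputs, losses = model.train_step(
    {"text_input": torch.randint(0, 100, (2, 13)), "mel_input": torch.randn(2, 13, 80)},
    nn.L1Loss(),
)
```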