mirror of https://github.com/coqui-ai/TTS.git
`tts` model abstraction with `TTSModel`
This commit is contained in:
parent d4dbd89752
commit 6d7b5fbcde
TTS/tts/models/abstract_tts.py

@@ -0,0 +1,134 @@
from coqpit import Coqpit
from abc import ABC, abstractmethod
from typing import Dict, Tuple

import numpy as np
import torch
from torch import nn

from TTS.utils.audio import AudioProcessor

# pylint: skip-file


class TTSModel(nn.Module, ABC):
    """Abstract TTS class. Every new `tts` model must inherit this.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    @abstractmethod
    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
        """Forward pass for the model, mainly used in training.

        You can be flexible here and use a different number of arguments and argument names, since it is
        mostly called by `train_step()` in training and is not exposed outside the class.

        Args:
            text (torch.Tensor): Input text character sequence ids.
            aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sort of input
                for the model.

        Returns:
            Dict: Model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
        """Forward pass for inference.

        After the model is trained, this is the only function that connects the model to the outside world.

        This function must only take a `text` input and a dictionary that has all the other model-specific
        inputs. We don't use `**kwargs` since it is problematic with the TorchScript API.

        Args:
            text (torch.Tensor): Input text character sequence ids.
            aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.

        Returns:
            Dict: Model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this is returned by the criterion
        ...
        return outputs_dict, loss_dict

    @abstractmethod
    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """Create visualizations and waveform examples for training.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): Audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: Training plots and output waveform.
        """
        figures_dict = {}
        output_wav = np.array([])
        ...
        return figures_dict, output_wav

    @abstractmethod
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you
        can call `train_step()` with no changes.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this is returned by the criterion
        ...
        return outputs_dict, loss_dict

    @abstractmethod
    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """The same as `train_log()`."""
        figures_dict = {}
        output_wav = np.array([])
        ...
        return figures_dict, output_wav

    @abstractmethod
    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
        """Load a checkpoint and get ready for training or inference.

        Args:
            config (Coqpit): Model configuration.
            checkpoint_path (str): Path to the model checkpoint file.
            eval (bool, optional): If True, initialize the model for inference, else for training. Defaults to False.
        """
        ...
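For reference, here is a minimal sketch (not part of this commit) of how a concrete subclass is expected to fill in the `TTSModel` contract. `ToyTTS`, its layer sizes, the batch keys `text_input`/`mel_input`, and the checkpoint layout are illustrative assumptions, not code from this repository:

from typing import Dict, Tuple

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn

from TTS.tts.models.abstract_tts import TTSModel
from TTS.utils.audio import AudioProcessor


class ToyTTS(TTSModel):
    """Hypothetical subclass: embeds character ids and projects them to a mel frame per step."""

    def __init__(self, num_chars: int = 128, hidden: int = 64, out_channels: int = 80):
        super().__init__()
        self.emb = nn.Embedding(num_chars, hidden)
        self.decoder = nn.Linear(hidden, out_channels)  # stands in for a real decoder

    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
        o = self.decoder(self.emb(text))  # [B, T, C], following the shape convention above
        return {"model_outputs": o}

    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
        with torch.no_grad():
            return self.forward(text)

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        outputs = self.forward(batch["text_input"])  # batch keys are assumptions
        loss_dict = {"loss": criterion(outputs["model_outputs"], batch["mel_input"])}
        return outputs, loss_dict

    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)  # evaluation needs no changes here

    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        # Turn the first predicted spectrogram into audio; transposed to [C, T] for the processor.
        mel = outputs["model_outputs"][0].detach().cpu().numpy()
        return {}, ap.inv_melspectrogram(mel.T)  # figures dict left empty in this sketch

    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        return self.train_log(ap, batch, outputs)

    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
        state = torch.load(checkpoint_path, map_location="cpu")
        self.load_state_dict(state["model"])  # assumes a checkpoint dict with a "model" key
        if eval:
            self.eval()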
TTS/tts/models/align_tts.py

@@ -7,13 +7,14 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor


-class AlignTTS(nn.Module):
+class AlignTTS(TTSModel):
     """AlignTTS with modified duration predictor.
     https://arxiv.org/pdf/2003.01950.pdf
TTS/tts/models/glow_tts.py

@@ -7,13 +7,14 @@ from torch.nn import functional as F
 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor


-class GlowTTS(nn.Module):
+class GlowTTS(TTSModel):
     """Glow TTS models from https://arxiv.org/abs/2005.11129

     Args:
TTS/tts/models/speedy_speech.py

@@ -6,13 +6,14 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
+from TTS.tts.models.abstract_tts import TTSModel
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor


-class SpeedySpeech(nn.Module):
+class SpeedySpeech(TTSModel):
     """Speedy Speech model
     https://arxiv.org/abs/2008.03802
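What the shared base class buys: training code can now drive AlignTTS, GlowTTS, SpeedySpeech, or any future `tts` model through the same method surface. A short usage sketch, reusing the hypothetical `ToyTTS` from above (batch keys and shapes are assumptions):

import torch

model = ToyTTS()
criterion = torch.nn.L1Loss()
batch = {
    "text_input": torch.randint(0, 128, (2, 10)),  # [B, T] character ids
    "mel_input": torch.rand(2, 10, 80),            # [B, T, C] target spectrograms
}
outputs, losses = model.train_step(batch, criterion)
assert "model_outputs" in outputs  # the contract every `tts` model must honor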