From 7b8c15ac4915e03936089b352e723256e5baacec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 18 Jun 2021 13:35:36 +0200 Subject: [PATCH] =?UTF-8?q?Create=20base=20=F0=9F=90=B8TTS=20model=20abstr?= =?UTF-8?q?action=20for=20tts=20models?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TTS/{tts/models/abstract_tts.py => model.py} | 41 ++- TTS/tts/models/align_tts.py | 159 +++++++---- TTS/tts/models/base_tacotron.py | 286 +++++++++++++++++++ TTS/tts/models/base_tts.py | 233 +++++++++++++++ TTS/tts/models/glow_tts.py | 155 ++++------ TTS/tts/models/speedy_speech.py | 147 +++++++--- TTS/tts/models/tacotron.py | 216 +++++--------- TTS/tts/models/tacotron2.py | 206 +++++-------- TTS/tts/tf/models/tacotron2.py | 6 +- TTS/vocoder/models/base_vocoder.py | 20 ++ 10 files changed, 968 insertions(+), 501 deletions(-) rename TTS/{tts/models/abstract_tts.py => model.py} (86%) create mode 100644 TTS/tts/models/base_tacotron.py create mode 100644 TTS/tts/models/base_tts.py create mode 100644 TTS/vocoder/models/base_vocoder.py diff --git a/TTS/tts/models/abstract_tts.py b/TTS/model.py similarity index 86% rename from TTS/tts/models/abstract_tts.py rename to TTS/model.py index 9132f7eb..aefb925e 100644 --- a/TTS/tts/models/abstract_tts.py +++ b/TTS/model.py @@ -1,9 +1,9 @@ -from coqpit import Coqpit from abc import ABC, abstractmethod -from typing import Dict, Tuple +from typing import Dict, List, Tuple, Union import numpy as np import torch +from coqpit import Coqpit from torch import nn from TTS.utils.audio import AudioProcessor @@ -11,8 +11,8 @@ from TTS.utils.audio import AudioProcessor # pylint: skip-file -class TTSModel(nn.Module, ABC): - """Abstract TTS class. Every new `tts` model must inherit this. +class BaseModel(nn.Module, ABC): + """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this. Notes on input/output tensor shapes: Any input or output tensor of the model must be shaped as @@ -77,7 +77,6 @@ class TTSModel(nn.Module, ABC): ... return outputs_dict, loss_dict - @abstractmethod def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: """Create visualizations and waveform examples for training. @@ -92,10 +91,7 @@ class TTSModel(nn.Module, ABC): Returns: Tuple[Dict, np.ndarray]: training plots and output waveform. """ - figures_dict = {} - output_wav = np.array() - ... - return figures_dict, output_wav + return None, None @abstractmethod def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: @@ -114,13 +110,9 @@ class TTSModel(nn.Module, ABC): ... return outputs_dict, loss_dict - @abstractmethod def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: """The same as `train_log()`""" - figures_dict = {} - output_wav = np.array() - ... - return figures_dict, output_wav + return None, None @abstractmethod def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None: @@ -132,3 +124,24 @@ class TTSModel(nn.Module, ABC): eval (bool, optional): If true, init model for inference else for training. Defaults to False. """ ... + + def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]: + """Setup an return optimizer or optimizers.""" + pass + + def get_lr(self) -> Union[float, List[float]]: + """Return learning rate(s). + + Returns: + Union[float, List[float]]: Model's initial learning rates. 
+ """ + pass + + def get_scheduler(self, optimizer: torch.optim.Optimizer): + pass + + def get_criterion(self): + pass + + def format_batch(self): + pass diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 75fb50de..dbd57b83 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -1,5 +1,9 @@ +from dataclasses import dataclass, field +from typing import Dict, Tuple + import torch import torch.nn as nn +from coqpit import Coqpit from TTS.tts.layers.align_tts.mdn import MDNBlock from TTS.tts.layers.feed_forward.decoder import Decoder @@ -7,36 +11,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path -from TTS.tts.models.abstract_tts import TTSModel +from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -class AlignTTS(TTSModel): - """AlignTTS with modified duration predictor. - https://arxiv.org/pdf/2003.01950.pdf - - Encoder -> DurationPredictor -> Decoder - - AlignTTS's Abstract - Targeting at both high efficiency and performance, we propose AlignTTS to predict the - mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a - sequence of characters, and the duration of each character is determined by a duration predictor.Instead of - adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented - to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s - how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean - option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. - - Note: - Original model uses a separate character embedding layer for duration predictor. However, it causes the - duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, - we predict durations based on encoder outputs which has higher level information about input characters. This - enables training without phases as in the original paper. - - Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture - differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. - +@dataclass +class AlignTTSArgs(Coqpit): + """ Args: num_chars (int): number of unique input to characters @@ -64,43 +48,98 @@ class AlignTTS(TTSModel): number of channels in speaker embedding vectors. Defaults to 0. 
""" + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + hidden_channels_dp: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + length_scale: float = 1.0 + num_speakers: int = 0 + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +class AlignTTS(BaseTTS): + """AlignTTS with modified duration predictor. + https://arxiv.org/pdf/2003.01950.pdf + + Encoder -> DurationPredictor -> Decoder + + Check ```AlignTTSArgs``` for the class arguments. + + Examples: + >>> from TTS.tts.configs import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> config.model_args.num_chars = 50 + >>> model = AlignTTS(config) + + Paper Abstract: + Targeting at both high efficiency and performance, we propose AlignTTS to predict the + mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a + sequence of characters, and the duration of each character is determined by a duration predictor.Instead of + adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented + to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s + how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean + option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. + + Note: + Original model uses a separate character embedding layer for duration predictor. However, it causes the + duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, + we predict durations based on encoder outputs which has higher level information about input characters. This + enables training without phases as in the original paper. + + Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture + differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. 
+ + """ + # pylint: disable=dangerous-default-value - def __init__( - self, - num_chars, - out_channels, - hidden_channels=256, - hidden_channels_dp=256, - encoder_type="fftransformer", - encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, - decoder_type="fftransformer", - decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, - length_scale=1, - num_speakers=0, - external_c=False, - c_in_channels=0, - ): + def __init__(self, config: Coqpit): super().__init__() + self.config = config self.phase = -1 - self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale - self.emb = nn.Embedding(num_chars, hidden_channels) - self.pos_encoder = PositionalEncoding(hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) - self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) - self.duration_predictor = DurationPredictor(hidden_channels_dp) + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) - self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) - self.mdn_block = MDNBlock(hidden_channels, 2 * out_channels) + self.embedded_speaker_dim = 0 + self.init_multispeaker(config) - if num_speakers > 1 and not external_c: - # speaker embedding layer - self.emb_g = nn.Embedding(num_speakers, c_in_channels) - nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + self.embedded_speaker_dim, + ) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp) - if c_in_channels > 0 and c_in_channels != hidden_channels: - self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1) + self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1) + + self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels) + + if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1) @staticmethod def compute_log_probs(mu, log_sigma, y): @@ -164,11 +203,12 @@ class AlignTTS(TTSModel): # project g to decoder dim. 
if hasattr(self, "proj_g"): g = self.proj_g(g) + return x + g def _forward_encoder(self, x, x_lengths, g=None): if hasattr(self, "emb_g"): - g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1] if g is not None: g = g.unsqueeze(-1) @@ -315,7 +355,9 @@ class AlignTTS(TTSModel): loss_dict["align_error"] = align_error return outputs, loss_dict - def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + def train_log( + self, ap: AudioProcessor, batch: dict, outputs: dict + ) -> Tuple[Dict, Dict]: # pylint: disable=no-self-use model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] mel_input = batch["mel_input"] @@ -332,7 +374,7 @@ class AlignTTS(TTSModel): # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - return figures, train_audio + return figures, {"audio": train_audio} def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) @@ -349,6 +391,11 @@ class AlignTTS(TTSModel): self.eval() assert not self.training + def get_criterion(self): + from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel + + return AlignTTSLoss(self.config) + @staticmethod def _set_phase(config, global_step): """Decide AlignTTS training phase""" diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py new file mode 100644 index 00000000..a99e1926 --- /dev/null +++ b/TTS/tts/models/base_tacotron.py @@ -0,0 +1,286 @@ +import copy +from abc import abstractmethod +from dataclasses import dataclass +from typing import Dict, List + +import torch +from coqpit import MISSING, Coqpit +from torch import nn + +from TTS.tts.layers.losses import TacotronLoss +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.text import make_symbols +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.training import gradual_training_scheduler + + +@dataclass +class BaseTacotronArgs(Coqpit): + """TODO: update Tacotron configs using it""" + + num_chars: int = MISSING + num_speakers: int = MISSING + r: int = MISSING + out_channels: int = 80 + decoder_output_dim: int = 80 + attn_type: str = "original" + attn_win: bool = False + attn_norm: str = "softmax" + prenet_type: str = "original" + prenet_dropout: bool = True + prenet_dropout_at_inference: bool = False + forward_attn: bool = False + trans_agent: bool = False + forward_attn_mask: bool = False + location_attn: bool = True + attn_K: int = 5 + separate_stopnet: bool = True + bidirectional_decoder: bool = False + double_decoder_consistency: bool = False + ddc_r: int = None + encoder_in_features: int = 512 + decoder_in_features: int = 512 + d_vector_dim: int = None + use_gst: bool = False + gst: bool = None + gradual_training: bool = None + + +class BaseTacotron(BaseTTS): + def __init__(self, config: Coqpit): + """Abstract Tacotron class""" + super().__init__() + + for key in config: + setattr(self, key, config[key]) + + # layers + self.embedding = None + self.encoder = None + self.decoder = None + self.postnet = None + + # init tensors + self.embedded_speakers = None + self.embedded_speakers_projected = None + + # global style token + if self.gst and self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim + self.gst_layer = None + + # additional layers + self.decoder_backward 
= None + self.coarse_decoder = None + + # init multi-speaker layers + self.init_multispeaker(config) + + @staticmethod + def _format_aux_input(aux_input: Dict) -> Dict: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + + ############################# + # INIT FUNCTIONS + ############################# + + def _init_states(self): + self.embedded_speakers = None + self.embedded_speakers_projected = None + + def _init_backward_decoder(self): + self.decoder_backward = copy.deepcopy(self.decoder) + + def _init_coarse_decoder(self): + self.coarse_decoder = copy.deepcopy(self.decoder) + self.coarse_decoder.r_init = self.ddc_r + self.coarse_decoder.set_r(self.ddc_r) + + ############################# + # CORE FUNCTIONS + ############################# + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self): + pass + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if "r" in state: + self.decoder.set_r(state["r"]) + else: + self.decoder.set_r(state["config"]["r"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self) -> nn.Module: + return TacotronLoss(self.config) + + @staticmethod + def get_characters(config: Coqpit) -> str: + # TODO: implement CharacterProcessor + if config.characters is not None: + symbols, phonemes = make_symbols(**config.characters) + else: + from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel + parse_symbols, + phonemes, + symbols, + ) + + config.characters = parse_symbols() + model_characters = phonemes if config.use_phonemes else symbols + return model_characters, config + + @staticmethod + def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: + return get_speaker_manager(config, restore_path, data, out_path) + + def get_aux_input(self, **kwargs) -> Dict: + """Compute Tacotron's auxiliary inputs based on model config. + - speaker d_vector + - style wav for GST + - speaker ID for speaker embedding + """ + # setup speaker_id + if self.config.use_speaker_embedding: + speaker_id = kwargs.get("speaker_id", 0) + else: + speaker_id = None + # setup d_vector + d_vector = ( + self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0]) + if self.config.use_d_vector_file and self.config.use_speaker_embedding + else None + ) + # setup style_mel + if "style_wav" in kwargs: + style_wav = kwargs["style_wav"] + elif self.config.has("gst_style_input"): + style_wav = self.config.gst_style_input + else: + style_wav = None + if style_wav is None and "use_gst" in self.config and self.config.use_gst: + # inicialize GST with zero dict. 
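+                # NOTE: `compute_gst()` interprets a dict as {style_token_index: amplitude}, so all-zero
+                # amplitudes yield a zero style vector, i.e. effectively no style conditioning.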
+ style_wav = {} + print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!") + for i in range(self.config.gst["gst_num_style_tokens"]): + style_wav[str(i)] = 0 + aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} + return aux_inputs + + ############################# + # COMMON COMPUTE FUNCTIONS + ############################# + + def compute_masks(self, text_lengths, mel_lengths): + """Compute masks against sequence paddings.""" + # B x T_in_max (boolean) + input_mask = sequence_mask(text_lengths) + output_mask = None + if mel_lengths is not None: + max_len = mel_lengths.max() + r = self.decoder.r + max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len + output_mask = sequence_mask(mel_lengths, max_len=max_len) + return input_mask, output_mask + + def _backward_pass(self, mel_specs, encoder_outputs, mask): + """Run backwards decoder""" + decoder_outputs_b, alignments_b, _ = self.decoder_backward( + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) + decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() + return decoder_outputs_b, alignments_b + + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): + """Double Decoder Consistency""" + T = mel_specs.shape[1] + if T % self.coarse_decoder.r > 0: + padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) + decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( + encoder_outputs.detach(), mel_specs, input_mask + ) + # scale_factor = self.decoder.r_init / self.decoder.r + alignments_backward = torch.nn.functional.interpolate( + alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest" + ).transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward[:, :T, :] + return decoder_outputs_backward, alignments_backward + + ############################# + # EMBEDDING FUNCTIONS + ############################# + + def compute_speaker_embedding(self, speaker_ids): + """Compute speaker embedding vectors""" + if hasattr(self, "speaker_embedding") and speaker_ids is None: + raise RuntimeError(" [!] 
Model has speaker embedding layer but speaker_id is not provided") + if hasattr(self, "speaker_embedding") and speaker_ids is not None: + self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1) + if hasattr(self, "speaker_project_mel") and speaker_ids is not None: + self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1) + + def compute_gst(self, inputs, style_input, speaker_embedding=None): + """Compute global style token""" + if isinstance(style_input, dict): + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) + if speaker_embedding is not None: + query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) + + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + for k_token, v_amplifier in style_input.items(): + key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + else: + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) + return inputs + + @staticmethod + def _add_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = outputs + embedded_speakers_ + return outputs + + @staticmethod + def _concat_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, embedded_speakers_], dim=-1) + return outputs + + ############################# + # CALLBACKS + ############################# + + def on_epoch_start(self, trainer): + """Callback for setting values wrt gradual training schedule. + + Args: + trainer (TrainerTTS): TTS trainer object that is used to train this model. + """ + if self.gradual_training: + r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) + trainer.config.r = r + self.decoder.set_r(r) + if trainer.config.bidirectional_decoder: + trainer.model.decoder_backward.set_r(r) + trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True) + trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r) + print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py new file mode 100644 index 00000000..1de7ba92 --- /dev/null +++ b/TTS/tts/models/base_tts.py @@ -0,0 +1,233 @@ +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from TTS.model import BaseModel +from TTS.tts.datasets import TTSDataset +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text import make_symbols +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor + +# pylint: skip-file + + +class BaseTTS(BaseModel): + """Abstract `tts` class. Every new `tts` model must inherit this. 
+ + It defines `tts` specific functions on top of `Model`. + + Notes on input/output tensor shapes: + Any input or output tensor of the model must be shaped as + + - 3D tensors `batch x time x channels` + - 2D tensors `batch x channels` + - 1D tensors `batch x 1` + """ + + @staticmethod + def get_characters(config: Coqpit) -> str: + # TODO: implement CharacterProcessor + if config.characters is not None: + symbols, phonemes = make_symbols(**config.characters) + else: + from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols + + config.characters = parse_symbols() + model_characters = phonemes if config.use_phonemes else symbols + return model_characters, config + + def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager: + return get_speaker_manager(config, restore_path, data, out_path) + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + If you need a different behaviour, override this function for your model. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + # init speaker manager + self.speaker_manager = get_speaker_manager(config, data=data) + self.num_speakers = self.speaker_manager.num_speakers + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + pass + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `TTSDataset`. + + You must override this if you use a custom dataset. + + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch[0] + text_lengths = batch[1] + speaker_names = batch[2] + linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None + mel_input = batch[4] + mel_lengths = batch[5] + stop_targets = batch[6] + item_idx = batch[7] + d_vectors = batch[8] + speaker_ids = batch[9] + attn_mask = batch[10] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] 
total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets view, we predict a single stop token per iteration. + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + } + + def get_data_loader( + self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if hasattr(self, "speaker_manager"): + speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None + d_vector_mapping = ( + self.speaker_manager.d_vectors + if config.use_speaker_embedding and config.use_d_vector_file + else None + ) + else: + speaker_id_mapping = None + d_vector_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + text_cleaner=config.text_cleaner, + compute_linear_spec=config.model.lower() == "tacotron", + meta_data=data_items, + ap=ap, + tp=config.characters, + add_blank=config["add_blank"], + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_seq_len=config.min_seq_len, + max_seq_len=config.max_seq_len, + phoneme_cache_path=config.phoneme_cache_path, + use_phonemes=config.use_phonemes, + phoneme_language=config.phoneme_language, + enable_eos_bos=config.enable_eos_bos_chars, + use_noise_augment=not is_eval, + verbose=verbose, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping + if config.use_speaker_embedding and config.use_d_vector_file + else None, + ) + + if config.use_phonemes and config.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(config.num_loader_workers) + dataset.sort_items() + + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def test_run(self) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
+ """ + print(" | > Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_aux_inputs() + for idx, sen in enumerate(test_sentences): + wav, alignment, model_outputs, _ = synthesis( + self.model, + sen, + self.config, + self.use_cuda, + self.ap, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + enable_eos_bos_chars=self.config.enable_eos_bos_chars, + use_griffin_lim=True, + do_trim_silence=False, + ).values() + + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) + return test_figures, test_audios diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index a30eadb4..ca2682dc 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -4,131 +4,89 @@ import torch from torch import nn from torch.nn import functional as F +from TTS.tts.configs import GlowTTSConfig from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.layers.glow_tts.encoder import Encoder from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path -from TTS.tts.models.abstract_tts import TTSModel +from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -class GlowTTS(TTSModel): +class GlowTTS(BaseTTS): """Glow TTS models from https://arxiv.org/abs/2005.11129 - Args: - num_chars (int): number of embedding characters. - hidden_channels_enc (int): number of embedding and encoder channels. - hidden_channels_dec (int): number of decoder channels. - use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder. - hidden_channels_dp (int): number of duration predictor channels. - out_channels (int): number of output channels. It should be equal to the number of spectrogram filter. - num_flow_blocks_dec (int): number of decoder blocks. - kernel_size_dec (int): decoder kernel size. - dilation_rate (int): rate to increase dilation by each layer in a decoder block. - num_block_layers (int): number of decoder layers in each decoder block. - dropout_p_dec (float): dropout rate for decoder. - num_speaker (int): number of speaker to define the size of speaker embedding layer. - c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned. - num_splits (int): number of split levels in inversible conv1x1 operation. - num_squeeze (int): number of squeeze levels. When squeezing channels increases and time steps reduces by the factor 'num_squeeze'. - sigmoid_scale (bool): enable/disable sigmoid scaling in decoder. - mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step. - encoder_type (str): encoder module type. - encoder_params (dict): encoder module parameters. - d_vector_dim (int): channels of external speaker embedding vectors. + Paper abstract: + Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate + mel-spectrograms from text in parallel. 
Despite the advantage, the parallel TTS models cannot be trained + without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, + a flow-based generative model for parallel TTS that does not require any external aligner. By combining the + properties of flows and dynamic programming, the proposed model searches for the most probable monotonic + alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard + monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows + enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over + the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our + model can be easily extended to a multi-speaker setting. + + Check `GlowTTSConfig` for class arguments. """ - def __init__( - self, - num_chars, - hidden_channels_enc, - hidden_channels_dec, - use_encoder_prenet, - hidden_channels_dp, - out_channels, - num_flow_blocks_dec=12, - inference_noise_scale=0.33, - kernel_size_dec=5, - dilation_rate=5, - num_block_layers=4, - dropout_p_dp=0.1, - dropout_p_dec=0.05, - num_speakers=0, - c_in_channels=0, - num_splits=4, - num_squeeze=1, - sigmoid_scale=False, - mean_only=False, - encoder_type="transformer", - encoder_params=None, - d_vector_dim=None, - ): + def __init__(self, config: GlowTTSConfig): super().__init__() - self.num_chars = num_chars - self.hidden_channels_dp = hidden_channels_dp - self.hidden_channels_enc = hidden_channels_enc - self.hidden_channels_dec = hidden_channels_dec - self.out_channels = out_channels - self.num_flow_blocks_dec = num_flow_blocks_dec - self.kernel_size_dec = kernel_size_dec - self.dilation_rate = dilation_rate - self.num_block_layers = num_block_layers - self.dropout_p_dec = dropout_p_dec - self.num_speakers = num_speakers - self.c_in_channels = c_in_channels - self.num_splits = num_splits - self.num_squeeze = num_squeeze - self.sigmoid_scale = sigmoid_scale - self.mean_only = mean_only - self.use_encoder_prenet = use_encoder_prenet - self.inference_noise_scale = inference_noise_scale - # model constants. - self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference. - self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech. 
- self.d_vector_dim = d_vector_dim + chars, self.config = self.get_characters(config) + self.num_chars = len(chars) + self.decoder_output_dim = config.out_channels + self.init_multispeaker(config) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) # if is a multispeaker and c_in_channels is 0, set to 256 - if num_speakers > 1: - if self.c_in_channels == 0 and not self.d_vector_dim: + self.c_in_channels = 0 + if self.num_speakers > 1: + if self.d_vector_dim: + self.c_in_channels = self.d_vector_dim + elif self.c_in_channels == 0 and not self.d_vector_dim: # TODO: make this adjustable self.c_in_channels = 256 - elif self.d_vector_dim: - self.c_in_channels = self.d_vector_dim self.encoder = Encoder( - num_chars, - out_channels=out_channels, - hidden_channels=hidden_channels_enc, - hidden_channels_dp=hidden_channels_dp, - encoder_type=encoder_type, - encoder_params=encoder_params, - mean_only=mean_only, - use_prenet=use_encoder_prenet, - dropout_p_dp=dropout_p_dp, + self.num_chars, + out_channels=self.out_channels, + hidden_channels=self.hidden_channels_enc, + hidden_channels_dp=self.hidden_channels_dp, + encoder_type=self.encoder_type, + encoder_params=self.encoder_params, + mean_only=self.mean_only, + use_prenet=self.use_encoder_prenet, + dropout_p_dp=self.dropout_p_dp, c_in_channels=self.c_in_channels, ) self.decoder = Decoder( - out_channels, - hidden_channels_dec, - kernel_size_dec, - dilation_rate, - num_flow_blocks_dec, - num_block_layers, - dropout_p=dropout_p_dec, - num_splits=num_splits, - num_squeeze=num_squeeze, - sigmoid_scale=sigmoid_scale, + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, c_in_channels=self.c_in_channels, ) - if num_speakers > 1 and not d_vector_dim: + if self.num_speakers > 1 and not self.d_vector_dim: # speaker embedding layer - self.emb_g = nn.Embedding(num_speakers, self.c_in_channels) + self.emb_g = nn.Embedding(self.num_speakers, self.c_in_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) @staticmethod @@ -377,7 +335,7 @@ class GlowTTS(TTSModel): # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - return figures, train_audio + return figures, {"audio": train_audio} def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) @@ -406,3 +364,8 @@ class GlowTTS(TTSModel): self.eval() self.store_inverse() assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel + + return GlowTTSLoss() diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 44a47722..2eb70a6b 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -1,4 +1,7 @@ +from dataclasses import dataclass, field + import torch +from coqpit import Coqpit from torch import nn from TTS.tts.layers.feed_forward.decoder import Decoder @@ -6,25 +9,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.monotonic_align import generate_path -from TTS.tts.models.abstract_tts import TTSModel +from TTS.tts.models.base_tts 
import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -class SpeedySpeech(TTSModel): - """Speedy Speech model - https://arxiv.org/abs/2008.03802 - - Encoder -> DurationPredictor -> Decoder - - This model is able to achieve a reasonable performance with only - ~3M model parameters and convolutional layers. - - This model requires precomputed phoneme durations to train a duration predictor. At inference - it only uses the duration predictor to compute durations and expand encoder outputs respectively. - +@dataclass +class SpeedySpeechArgs(Coqpit): + """ Args: num_chars (int): number of unique input to characters out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size. @@ -36,49 +30,107 @@ class SpeedySpeech(TTSModel): decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'. decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }. num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0. - external_c (bool, optional): enable external speaker embeddings. Defaults to False. - c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0. + use_d_vector (bool, optional): enable external speaker embeddings. Defaults to False. + d_vector_dim (int, optional): number of channels in speaker embedding vectors. Defaults to 0. """ - # pylint: disable=dangerous-default-value - - def __init__( - self, - num_chars, - out_channels, - hidden_channels, - positional_encoding=True, - length_scale=1, - encoder_type="residual_conv_bn", - encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, - decoder_type="residual_conv_bn", - decoder_params={ + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 128 + num_speakers: int = 0 + positional_encoding: bool = True + length_scale: int = 1 + encoder_type: str = "residual_conv_bn" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + } + ) + decoder_type: str = "residual_conv_bn" + decoder_params: dict = field( + default_factory=lambda: { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17, - }, - num_speakers=0, - external_c=False, - c_in_channels=0, - ): + } + ) + use_d_vector: bool = False + d_vector_dim: int = 0 + +class SpeedySpeech(BaseTTS): + """Speedy Speech model + https://arxiv.org/abs/2008.03802 + + Encoder -> DurationPredictor -> Decoder + + Paper abstract: + While recent neural sequence-to-sequence models have greatly improved the quality of speech + synthesis, there has not been a system capable of fast training, fast inference and high-quality audio synthesis + at the same time. We propose a student-teacher network capable of high-quality faster-than-real-time spectrogram + synthesis, with low requirements on computational resources and fast training time. We show that self-attention + layers are not necessary for generation of high quality audio. 
We utilize simple convolutional blocks with + residual connections in both student and teacher networks and use only a single attention layer in the teacher + model. Coupled with a MelGAN vocoder, our model's voice quality was rated significantly higher than Tacotron 2. + Our model can be efficiently trained on a single GPU and can run in real time even on a CPU. We provide both + our source code and audio samples in our GitHub repository. + + Notes: + The vanilla model is able to achieve a reasonable performance with only + ~3M model parameters and convolutional layers. + + This model requires precomputed phoneme durations to train a duration predictor. At inference + it only uses the duration predictor to compute durations and expand encoder outputs respectively. + + You can also mix and match different encoder and decoder networks beyond the paper. + + Check `SpeedySpeechArgs` for arguments. + """ + + # pylint: disable=dangerous-default-value + + def __init__(self, config: Coqpit): super().__init__() - self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale - self.emb = nn.Embedding(num_chars, hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) - if positional_encoding: - self.pos_encoder = PositionalEncoding(hidden_channels) - self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) - self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels) + self.config = config - if num_speakers > 1 and not external_c: + if "characters" in config: + chars, self.config = self.get_characters(config) + self.num_chars = len(chars) + + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + config.model_args.d_vector_dim, + ) + if config.model_args.positional_encoding: + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels + config.model_args.d_vector_dim) + + if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: # speaker embedding layer - self.emb_g = nn.Embedding(num_speakers, c_in_channels) + self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - if c_in_channels > 0 and c_in_channels != hidden_channels: - self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1) + if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) @staticmethod def expand_encoder_outputs(en, dr, x_mask, y_mask): @@ -244,7 +296,7 @@ class SpeedySpeech(TTSModel): # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - return figures, train_audio + return figures, {"audio": train_audio} def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) @@ -260,3 
+312,8 @@ class SpeedySpeech(TTSModel): if eval: self.eval() assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import SpeedySpeechLoss # pylint: disable=import-outside-toplevel + + return SpeedySpeechLoss(self.config) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 317d1905..95b4a358 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -1,166 +1,86 @@ # coding: utf-8 + +from typing import Dict, Tuple + import torch +from coqpit import Coqpit from torch import nn from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG -from TTS.tts.models.tacotron_abstract import TacotronAbstract +from TTS.tts.models.base_tacotron import BaseTacotron from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor -class Tacotron(TacotronAbstract): +class Tacotron(BaseTacotron): """Tacotron as in https://arxiv.org/abs/1703.10135 - It's an autoregressive encoder-attention-decoder-postnet architecture. - - Args: - num_chars (int): number of input characters to define the size of embedding layer. - num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and model learns speaker embeddings. - r (int): initial model reduction rate. - postnet_output_dim (int, optional): postnet output channels. Defaults to 80. - decoder_output_dim (int, optional): decoder output channels. Defaults to 80. - attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'. - attn_win (bool, optional): enable/disable attention windowing. - It especially useful at inference to keep attention alignment diagonal. Defaults to False. - attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax". - prenet_type (str, optional): prenet type for the decoder. Defaults to "original". - prenet_dropout (bool, optional): prenet dropout rate. Defaults to True. - prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for - some models. Defaults to False. - forward_attn (bool, optional): enable/disable forward attention. - It is only valid if ```attn_type``` is ```original```. Defaults to False. - trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False. - forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False. - location_attn (bool, optional): enable/disable location sensitive attention. - It is only valid if ```attn_type``` is ```original```. Defaults to True. - attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5. - separate_stopnet (bool, optional): enable/disable separate stopnet training without only gradient - flow from stopnet to the rest of the model. Defaults to True. - bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False. - double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False. - ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None. - encoder_in_features (int, optional): input channels for the encoder. Defaults to 512. - decoder_in_features (int, optional): input channels for the decoder. Defaults to 512. 
- d_vector_dim (int, optional): external speaker conditioning vector channels. Defaults to None. - use_gst (bool, optional): enable/disable Global style token module. - gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. - memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size``` - output frames to the prenet. - gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used. - Defaults to `[]`. - max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000. + Check `TacotronConfig` for the arguments. """ - def __init__( - self, - num_chars, - num_speakers, - r=5, - postnet_output_dim=1025, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="sigmoid", - prenet_type="original", - prenet_dropout=True, - prenet_dropout_at_inference=False, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=256, - decoder_in_features=256, - d_vector_dim=None, - use_gst=False, - gst=None, - memory_size=5, - gradual_training=None, - max_decoder_steps=500, - ): - super().__init__( - num_chars, - num_speakers, - r, - postnet_output_dim, - decoder_output_dim, - attn_type, - attn_win, - attn_norm, - prenet_type, - prenet_dropout, - prenet_dropout_at_inference, - forward_attn, - trans_agent, - forward_attn_mask, - location_attn, - attn_K, - separate_stopnet, - bidirectional_decoder, - double_decoder_consistency, - ddc_r, - encoder_in_features, - decoder_in_features, - d_vector_dim, - use_gst, - gst, - gradual_training, - ) + def __init__(self, config: Coqpit): + super().__init__(config) - # speaker embedding layers + self.num_chars, self.config = self.get_characters(config) + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # speaker embedding layer if self.num_speakers > 1: - if not self.use_d_vectors: - d_vector_dim = 256 - self.speaker_embedding = nn.Embedding(self.num_speakers, d_vector_dim) - self.speaker_embedding.weight.data.normal_(0, 0.3) + self.init_multispeaker(config) # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: - self.decoder_in_features += d_vector_dim # add speaker embedding dim + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # embedding layer - self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) + self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) # base model layers self.encoder = Encoder(self.encoder_in_features) self.decoder = Decoder( self.decoder_in_features, - decoder_output_dim, - r, - memory_size, - attn_type, - attn_win, - attn_norm, - prenet_type, - prenet_dropout, - forward_attn, - trans_agent, - forward_attn_mask, - location_attn, - attn_K, - separate_stopnet, - max_decoder_steps, + self.decoder_output_dim, + self.r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, ) - self.postnet = 
PostCBHG(decoder_output_dim)
-        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
+        self.postnet = PostCBHG(self.decoder_output_dim)
+        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels)
 
         # setup prenet dropout
-        self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
+        self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference
 
         # global style token layers
         if self.gst and self.use_gst:
             self.gst_layer = GST(
-                num_mel=decoder_output_dim,
-                d_vector_dim=d_vector_dim,
-                num_heads=gst.gst_num_heads,
-                num_style_tokens=gst.gst_num_style_tokens,
-                gst_embedding_dim=gst.gst_embedding_dim,
+                num_mel=self.decoder_output_dim,
+                d_vector_dim=self.d_vector_dim
+                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
+                else None,
+                num_heads=self.gst.gst_num_heads,
+                num_style_tokens=self.gst.gst_num_style_tokens,
+                gst_embedding_dim=self.gst.gst_embedding_dim,
             )
 
         # backward pass decoder
         if self.bidirectional_decoder:
@@ -169,21 +89,21 @@ class Tacotron(TacotronAbstract):
         if self.double_decoder_consistency:
             self.coarse_decoder = Decoder(
                 self.decoder_in_features,
-                decoder_output_dim,
-                ddc_r,
-                memory_size,
-                attn_type,
-                attn_win,
-                attn_norm,
-                prenet_type,
-                prenet_dropout,
-                forward_attn,
-                trans_agent,
-                forward_attn_mask,
-                location_attn,
-                attn_K,
-                separate_stopnet,
-                max_decoder_steps,
+                self.decoder_output_dim,
+                self.ddc_r,
+                self.memory_size,
+                self.attention_type,
+                self.windowing,
+                self.attention_norm,
+                self.prenet_type,
+                self.prenet_dropout,
+                self.use_forward_attn,
+                self.transition_agent,
+                self.forward_attn_mask,
+                self.location_attn,
+                self.attention_heads,
+                self.separate_stopnet,
+                self.max_decoder_steps,
             )
 
     def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
@@ -205,7 +125,9 @@ class Tacotron(TacotronAbstract):
         # global style token
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
+            encoder_outputs = self.compute_gst(
+                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
+            )
         # speaker embedding
         if self.num_speakers > 1:
             if not self.use_d_vectors:
@@ -341,7 +263,7 @@ class Tacotron(TacotronAbstract):
             loss_dict["align_error"] = align_error
         return outputs, loss_dict
 
-    def train_log(self, ap, batch, outputs):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
         postnet_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         alignments_backward = outputs["alignments_backward"]
@@ -362,7 +284,7 @@ class Tacotron(TacotronAbstract):
 
         # Sample audio
         train_audio = ap.inv_spectrogram(pred_spec.T)
-        return figures, train_audio
+        return figures, {"audio": train_audio}
 
     def eval_step(self, batch, criterion):
         return self.train_step(batch, criterion)
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index d56bd988..eaca3ff8 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -1,160 +1,84 @@
 # coding: utf-8
+
+from typing import Dict, Tuple
+
 import torch
+from coqpit import Coqpit
 from torch import nn
 
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
-from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.models.base_tacotron import BaseTacotron
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 
 
-class Tacotron2(TacotronAbstract):
+class Tacotron2(BaseTacotron):
     """Tacotron2 as in https://arxiv.org/abs/1712.05884
-
-    It's an autoregressive encoder-attention-decoder-postnet architecture.
-
-    Args:
-        num_chars (int): number of input characters to define the size of embedding layer.
-        num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and model learns speaker embeddings.
-        r (int): initial model reduction rate.
-        postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
-        decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
-        attn_type (str, optional): attention type. Check ```TTS.tts.layers.tacotron.common_layers.init_attn```. Defaults to 'original'.
-        attn_win (bool, optional): enable/disable attention windowing.
-            It especially useful at inference to keep attention alignment diagonal. Defaults to False.
-        attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
-        prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
-        prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
-        prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
-            some models. Defaults to False.
-        forward_attn (bool, optional): enable/disable forward attention.
-            It is only valid if ```attn_type``` is ```original```. Defaults to False.
-        trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
-        forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False.
-        location_attn (bool, optional): enable/disable location sensitive attention.
-            It is only valid if ```attn_type``` is ```original```. Defaults to True.
-        attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5.
-        separate_stopnet (bool, optional): enable/disable separate stopnet training without only gradient
-            flow from stopnet to the rest of the model. Defaults to True.
-        bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False.
-        double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False.
-        ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None.
-        encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
-        decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
-        d_vector_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
-        use_gst (bool, optional): enable/disable Global style token module.
-        gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
-        gradual_training (List): Gradual training schedule. If None or `[]`, no gradual training is used.
-            Defaults to `[]`.
-        max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000.
+    Check `TacotronConfig` for the arguments.
     """
 
-    def __init__(
-        self,
-        num_chars,
-        num_speakers,
-        r,
-        postnet_output_dim=80,
-        decoder_output_dim=80,
-        attn_type="original",
-        attn_win=False,
-        attn_norm="softmax",
-        prenet_type="original",
-        prenet_dropout=True,
-        prenet_dropout_at_inference=False,
-        forward_attn=False,
-        trans_agent=False,
-        forward_attn_mask=False,
-        location_attn=True,
-        attn_K=5,
-        separate_stopnet=True,
-        bidirectional_decoder=False,
-        double_decoder_consistency=False,
-        ddc_r=None,
-        encoder_in_features=512,
-        decoder_in_features=512,
-        d_vector_dim=None,
-        use_gst=False,
-        gst=None,
-        gradual_training=None,
-        max_decoder_steps=500,
-    ):
-        super().__init__(
-            num_chars,
-            num_speakers,
-            r,
-            postnet_output_dim,
-            decoder_output_dim,
-            attn_type,
-            attn_win,
-            attn_norm,
-            prenet_type,
-            prenet_dropout,
-            prenet_dropout_at_inference,
-            forward_attn,
-            trans_agent,
-            forward_attn_mask,
-            location_attn,
-            attn_K,
-            separate_stopnet,
-            bidirectional_decoder,
-            double_decoder_consistency,
-            ddc_r,
-            encoder_in_features,
-            decoder_in_features,
-            d_vector_dim,
-            use_gst,
-            gst,
-            gradual_training,
-        )
+    def __init__(self, config: Coqpit):
+        super().__init__(config)
+
+        chars, self.config = self.get_characters(config)
+        self.num_chars = len(chars)
+        self.decoder_output_dim = config.out_channels
+
+        # pass all config fields to `self`
+        # for fewer code change
+        for key in config:
+            setattr(self, key, config[key])
 
         # speaker embedding layer
         if self.num_speakers > 1:
-            if not self.use_d_vectors:
-                d_vector_dim = 512
-            self.speaker_embedding = nn.Embedding(self.num_speakers, d_vector_dim)
-            self.speaker_embedding.weight.data.normal_(0, 0.3)
+            self.init_multispeaker(config)
 
         # speaker and gst embeddings is concat in decoder input
         if self.num_speakers > 1:
-            self.decoder_in_features += d_vector_dim  # add speaker embedding dim
+            self.decoder_in_features += self.embedded_speaker_dim  # add speaker embedding dim
+
+        if self.use_gst:
+            self.decoder_in_features += self.gst.gst_embedding_dim
 
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
+        self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)
 
         # base model layers
         self.encoder = Encoder(self.encoder_in_features)
         self.decoder = Decoder(
             self.decoder_in_features,
             self.decoder_output_dim,
-            r,
-            attn_type,
-            attn_win,
-            attn_norm,
-            prenet_type,
-            prenet_dropout,
-            forward_attn,
-            trans_agent,
-            forward_attn_mask,
-            location_attn,
-            attn_K,
-            separate_stopnet,
-            max_decoder_steps,
+            self.r,
+            self.attention_type,
+            self.attention_win,
+            self.attention_norm,
+            self.prenet_type,
+            self.prenet_dropout,
+            self.use_forward_attn,
+            self.transition_agent,
+            self.forward_attn_mask,
+            self.location_attn,
+            self.attention_heads,
+            self.separate_stopnet,
+            self.max_decoder_steps,
         )
-        self.postnet = Postnet(self.postnet_output_dim)
+        self.postnet = Postnet(self.out_channels)
 
         # setup prenet dropout
-        self.decoder.prenet.dropout_at_g = prenet_dropout_at_inference
+        self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference
 
         # global style token layers
-        if self.gst and use_gst:
+        if self.gst and self.use_gst:
             self.gst_layer = GST(
-                num_mel=decoder_output_dim,
-                d_vector_dim=d_vector_dim,
-                num_heads=gst.gst_num_heads,
-                num_style_tokens=gst.gst_num_style_tokens,
-                gst_embedding_dim=gst.gst_embedding_dim,
+                num_mel=self.decoder_output_dim,
+                d_vector_dim=self.d_vector_dim
+                if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
+                else None,
+                num_heads=self.gst.gst_num_heads,
+                num_style_tokens=self.gst.gst_num_style_tokens,
+                gst_embedding_dim=self.gst.gst_embedding_dim,
             )
 
         # backward pass decoder
@@ -165,19 +89,19 @@ class Tacotron2(TacotronAbstract):
             self.coarse_decoder = Decoder(
                 self.decoder_in_features,
                 self.decoder_output_dim,
-                ddc_r,
-                attn_type,
-                attn_win,
-                attn_norm,
-                prenet_type,
-                prenet_dropout,
-                forward_attn,
-                trans_agent,
-                forward_attn_mask,
-                location_attn,
-                attn_K,
-                separate_stopnet,
-                max_decoder_steps,
+                self.ddc_r,
+                self.attention_type,
+                self.attention_win,
+                self.attention_norm,
+                self.prenet_type,
+                self.prenet_dropout,
+                self.use_forward_attn,
+                self.transition_agent,
+                self.forward_attn_mask,
+                self.location_attn,
+                self.attention_heads,
+                self.separate_stopnet,
+                self.max_decoder_steps,
             )
 
     @staticmethod
@@ -206,7 +130,9 @@ class Tacotron2(TacotronAbstract):
         encoder_outputs = self.encoder(embedded_inputs, text_lengths)
         if self.gst and self.use_gst:
             # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, aux_input["d_vectors"])
+            encoder_outputs = self.compute_gst(
+                encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
+            )
         if self.num_speakers > 1:
             if not self.use_d_vectors:
                 # B x 1 x speaker_embed_dim
@@ -342,7 +268,7 @@ class Tacotron2(TacotronAbstract):
             loss_dict["align_error"] = align_error
         return outputs, loss_dict
 
-    def train_log(self, ap, batch, outputs):
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
         postnet_outputs = outputs["model_outputs"]
         alignments = outputs["alignments"]
         alignments_backward = outputs["alignments_backward"]
@@ -363,7 +289,7 @@ class Tacotron2(TacotronAbstract):
 
         # Sample audio
         train_audio = ap.inv_melspectrogram(pred_spec.T)
-        return figures, train_audio
+        return figures, {"audio": train_audio}
 
     def eval_step(self, batch, criterion):
         return self.train_step(batch, criterion)
diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py
index 9cc62070..7a1d695d 100644
--- a/TTS/tts/tf/models/tacotron2.py
+++ b/TTS/tts/tf/models/tacotron2.py
@@ -12,7 +12,7 @@ class Tacotron2(keras.models.Model):
         num_chars,
         num_speakers,
         r,
-        postnet_output_dim=80,
+        out_channels=80,
         decoder_output_dim=80,
         attn_type="original",
         attn_win=False,
@@ -31,7 +31,7 @@ class Tacotron2(keras.models.Model):
         super().__init__()
         self.r = r
         self.decoder_output_dim = decoder_output_dim
-        self.postnet_output_dim = postnet_output_dim
+        self.out_channels = out_channels
         self.bidirectional_decoder = bidirectional_decoder
         self.num_speakers = num_speakers
         self.speaker_embed_dim = 256
@@ -58,7 +58,7 @@ class Tacotron2(keras.models.Model):
             name="decoder",
             enable_tflite=enable_tflite,
         )
-        self.postnet = Postnet(postnet_output_dim, 5, name="postnet")
+        self.postnet = Postnet(out_channels, 5, name="postnet")
 
     @tf.function(experimental_relax_shapes=True)
     def call(self, characters, text_lengths=None, frames=None, training=None):
diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py
new file mode 100644
index 00000000..f879cd42
--- /dev/null
+++ b/TTS/vocoder/models/base_vocoder.py
@@ -0,0 +1,20 @@
+from TTS.model import BaseModel
+
+# pylint: skip-file
+
+
+class BaseVocoder(BaseModel):
+    """Base `vocoder` class. Every new `vocoder` model must inherit this.
+
+    It defines `vocoder` specific functions on top of `Model`.
+
+    Notes on input/output tensor shapes:
+        Any input or output tensor of the model must be shaped as
+
+        - 3D tensors `batch x time x channels`
+        - 2D tensors `batch x channels`
+        - 1D tensors `batch x 1`
+    """
+
+    def __init__(self):
+        super().__init__()
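
The snippet below is a minimal sketch, not part of this patch, of how a concrete vocoder could build on the new `BaseVocoder` abstraction while keeping the `train_step`/`eval_step`/`load_checkpoint` contract defined by `BaseModel`. The `MyVocoder` class, its toy generator layer, the `num_mels` config field, the `input`/`waveform` batch keys, and the `"model"` checkpoint key are illustrative assumptions only.

from typing import Dict, Tuple

import torch
from coqpit import Coqpit
from torch import nn

from TTS.vocoder.models.base_vocoder import BaseVocoder


class MyVocoder(BaseVocoder):
    """Hypothetical vocoder used only to illustrate the new base class."""

    def __init__(self, config: Coqpit):
        super().__init__()
        self.config = config
        # toy upsampling generator: mel frames -> waveform samples (assumed config field `num_mels`)
        self.generator = nn.Sequential(
            nn.ConvTranspose1d(config.num_mels, 1, kernel_size=16, stride=8, padding=4),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: B x T x C mel input (see the shape notes above) -> B x T_wav waveform
        return self.generator(x.transpose(1, 2)).squeeze(1)

    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        # assumed batch keys: "input" (B x T x C mel) and "waveform" (B x T_wav target audio)
        y_hat = self.forward(batch["input"])
        outputs = {"model_outputs": y_hat}
        loss_dict = {"loss": criterion(y_hat, batch["waveform"])}
        return outputs, loss_dict

    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        return self.train_step(batch, criterion)

    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
        # assumes the checkpoint stores the weights under a "model" key
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()

Keeping the same step/log interface across `tts` and `vocoder` models is what lets a single trainer loop drive either family through the shared `BaseModel` API.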