diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 00000000..88bbd192 --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.fast_pitch import FastPitchArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models. + + Example: + + >>> from TTS.tts.configs import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + huber_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_pitch" + # model specific params + model_args: FastPitchArgs = field(default_factory=FastPitchArgs) + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py new file mode 100644 index 00000000..9b826c3f --- /dev/null +++ b/TTS/tts/models/fast_pitch.py @@ -0,0 +1,377 @@ +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.monotonic_align import generate_path +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.data import sequence_mask +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +@dataclass +class FastPitchArgs(Coqpit): + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + num_speakers: int = 0 + duration_predictor_hidden_channels: int = 256 + duration_predictor_dropout: float = 0.1 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_dropout: float = 0.1 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + positional_encoding: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + use_d_vector: bool = False + d_vector_dim: int = 0 + + +class FastPitch(BaseTTS): + """FastPitch model. Very similart to SpeedySpeech model but with pitch prediction. + + Paper abstract: + We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental + frequency contours. The model predicts pitch contours during inference. By altering these predictions, + the generated speech can be more expressive, better match the semantic of the utterance, and in the end + more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech + that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall + quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead, + and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time + factor for mel-spectrogram synthesis of a typical utterance." + + Notes: + TODO + + Args: + config (Coqpit): Model coqpit class. + + Examples: + >>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs + >>> config = FastPitchArgs() + >>> model = FastPitch(config) + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + + super().__init__() + + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) + + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + config.model_args.d_vector_dim, + ) + + if config.model_args.positional_encoding: + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.duration_predictor_hidden_channels, + config.model_args.duration_predictor_kernel_size, + config.model_args.duration_predictor_dropout_p, + ) + + self.pitch_predictor = DurationPredictor( + config.model_args.hidden_channels + config.model_args.d_vector_dim, + config.model_args.pitch_predictor_hidden_channels, + config.model_args.pitch_predictor_kernel_size, + config.model_args.pitch_predictor_dropout_p, + ) + + self.pitch_emb = nn.Conv1d( + 1, + config.model_args.hidden_channels, + kernel_size=config.model_args.pitch_embedding_kernel_size, + padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), + ) + + self.register_buffer("pitch_mean", torch.zeros(1)) + self.register_buffer("pitch_std", torch.zeros(1)) + + if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: + # speaker embedding layer + self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def expand_encoder_outputs(en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Example: + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_pitch(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def forward( + self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=unused-argument + """ + Shapes: + x: :math:`[B, T_max]` + x_lengths: :math:`[B]` + y_lengths: :math:`[B]` + dr: :math:`[B, T_max]` + g: :math:`[B, C]` + pitch: :math:`[B, 1, T]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "durations_log": o_dr_log.squeeze(1), + "pitch": o_pitch, + "pitch_gt": avg_pitch, + "alignments": attn, + } + return outputs + + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """ + Shapes: + x: [B, T_max] + x_lengths: [B] + g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # input sequence should be greated than the max convolution size + inference_padding = 5 + if x.shape[1] < 13: + inference_padding += 13 - x.shape[1] + # pad input to prevent dropping the last word + x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + # pitch predictor pass + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) + # if pitch_transform is not None: + # if self.pitch_std[0] == 0.0: + # # XXX LJSpeech-1.1 defaults + # mean, std = 218.14, 67.24 + # else: + # mean, std = self.pitch_mean[0], self.pitch_std[0] + # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) + + # if pitch_tgt is None: + # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) + # else: + # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) + o_en = o_en + o_pitch_emb + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "pitch": o_pitch, "durations_log": None} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input) + + # compute loss + loss_dict = criterion( + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + torch.log(1 + durations), + outputs["pitch"], + outputs["pitch_gt"], + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): + return self.train_log(ap, batch, outputs) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import FastPitchLoss # pylint: disable=import-outside-toplevel + + return FastPitchLoss(self.config) + + +def average_pitch(pitch, durs): + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) + pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = pitch.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float() + pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float() + + pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems) + return pitch_avg