from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit

from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor

# pylint: disable=dangerous-default-value


def mask_from_lens(lens, max_len: int = None):
    if max_len is None:
        max_len = lens.max()
    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
    mask = torch.lt(ids, lens.unsqueeze(1))
    return mask


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        batch_norm=False,
    ):
        super().__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.norm = torch.nn.BatchNorm1d(out_channels) if batch_norm else None

        torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        if self.norm is None:
            return self.conv(signal)
        return self.norm(self.conv(signal))


class ConvReLUNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2))
        self.norm = torch.nn.LayerNorm(out_channels)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, signal):
        out = F.relu(self.conv(signal))
        out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
        return self.dropout(out)


class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super().__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
        if bsz is not None:
            return pos_emb[None, :, :].expand(bsz, -1, -1)
        return pos_emb[None, :, :]
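

# Illustrative shape sketch (the values below are made up): `mask_from_lens`
# builds a boolean padding mask from a tensor of lengths, and
# `PositionalEmbedding` returns the usual sinusoidal table for a 1D position
# sequence.
#   mask_from_lens(torch.tensor([3, 5]))          # -> bool mask of shape [2, 5]
#   PositionalEmbedding(384)(torch.arange(5.0))   # -> tensor of shape [1, 5, 384]
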
class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super().__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out).to(inp.dtype)
        return output


class MultiHeadAttn(nn.Module):
    def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False):
        super().__init__()

        self.num_heads = num_heads
        self.d_model = d_model
        self.hidden_channels_head = hidden_channels_head
        self.scale = 1 / (hidden_channels_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head)
        self.drop = nn.Dropout(dropout)
        self.dropout_attn = nn.Dropout(dropout_attn)
        self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
        head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
        head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
        head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)

        # head-major ordering [num_heads * batch, time, head_dim] to match the
        # mask repeat and the reshape of `attn_vec` below
        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
            attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1)
            attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropout_attn(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head)
        attn_vec = (
            attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head)
        )

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output


class TransformerLayer(nn.Module):
    def __init__(
        self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs
    ):
        super().__init__()

        self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(
            hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm")
        )

    def forward(self, dec_inp, mask=None):
        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
        output *= mask
        output = self.pos_ff(output)
        output *= mask
        return output
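

# `TransformerLayer` is the FastSpeech-style FFT block used by both the encoder
# and decoder stacks below: multi-head self-attention followed by a 1D
# convolutional feed-forward module, with the padding mask re-applied after
# each sub-layer. Inputs and outputs are shaped [batch, time, hidden_channels]
# and `mask` is expected as [batch, time, 1].
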
class FFTransformer(nn.Module):
    def __init__(
        self,
        num_layers,
        num_heads,
        hidden_channels,
        hidden_channels_head,
        hidden_channels_ffn,
        kernel_size,
        dropout,
        dropout_attn,
        dropemb=0.0,
        pre_lnorm=False,
    ):
        super().__init__()

        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.hidden_channels_head = hidden_channels_head

        self.pos_emb = PositionalEmbedding(self.hidden_channels)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(
                TransformerLayer(
                    num_heads,
                    hidden_channels,
                    hidden_channels_head,
                    hidden_channels_ffn,
                    kernel_size,
                    dropout,
                    dropout_attn=dropout_attn,
                    pre_lnorm=pre_lnorm,
                )
            )

    def forward(self, x, x_lengths, conditioning=0):
        mask = mask_from_lens(x_lengths).unsqueeze(2)

        pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype)
        pos_emb = self.pos_emb(pos_seq) * mask
        if conditioning is None:
            conditioning = 0
        out = self.drop(x + pos_emb + conditioning)

        for layer in self.layers:
            out = layer(out, mask=mask)

        # out = self.drop(out)
        return out, mask


def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
    """Expand each encoder frame by its (predicted or ground-truth) duration, scaled by ``pace``."""
    dtype = enc_out.dtype
    reps = durations.float() / pace
    reps = (reps + 0.5).long()
    dec_lens = reps.sum(dim=1)

    max_len = dec_lens.max()
    reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
    reps_cumsum = reps_cumsum.to(dtype)

    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
    mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
    mult = mult.to(dtype)
    en_ex = torch.matmul(mult, enc_out)

    if mel_max_len:
        en_ex = en_ex[:, :mel_max_len]
        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
    return en_ex, dec_lens


class TemporalPredictor(nn.Module):
    """Predict a single float value per temporal location."""

    def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2):
        super().__init__()

        self.layers = nn.Sequential(
            *[
                ConvReLUNorm(
                    input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout
                )
                for i in range(num_layers)
            ]
        )
        self.fc = nn.Linear(filter_size, 1, bias=True)

    def forward(self, enc_out, enc_out_mask):
        out = enc_out * enc_out_mask
        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
        out = self.fc(out) * enc_out_mask
        return out.squeeze(-1)


@dataclass
class FastPitchArgs(Coqpit):
    num_chars: int = 100
    out_channels: int = 80
    hidden_channels: int = 384
    num_speakers: int = 0
    duration_predictor_hidden_channels: int = 256
    duration_predictor_dropout: float = 0.1
    duration_predictor_kernel_size: int = 3
    duration_predictor_dropout_p: float = 0.1
    duration_predictor_num_layers: int = 2
    pitch_predictor_hidden_channels: int = 256
    pitch_predictor_dropout: float = 0.1
    pitch_predictor_kernel_size: int = 3
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3
    pitch_predictor_num_layers: int = 2
    positional_encoding: bool = True
    length_scale: int = 1
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_head": 64,
            "hidden_channels_ffn": 1536,
            "num_heads": 1,
            "num_layers": 6,
            "kernel_size": 3,
            "dropout": 0.1,
            "dropout_attn": 0.1,
        }
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {
            "hidden_channels_head": 64,
            "hidden_channels_ffn": 1536,
            "num_heads": 1,
            "num_layers": 6,
            "kernel_size": 3,
            "dropout": 0.1,
            "dropout_attn": 0.1,
        }
    )
    use_d_vector: bool = False
    d_vector_dim: int = 0
    detach_duration_predictor: bool = False
    max_duration: int = 75
    use_gt_duration: bool = True
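

# Illustrative override (hypothetical values): `encoder_params` / `decoder_params`
# are passed straight into `FFTransformer`, so a lighter model could be sketched as
#   args = FastPitchArgs(
#       hidden_channels=256,
#       encoder_params={
#           "hidden_channels_head": 64,
#           "hidden_channels_ffn": 1024,
#           "num_heads": 1,
#           "num_layers": 4,
#           "kernel_size": 3,
#           "dropout": 0.1,
#           "dropout_attn": 0.1,
#       },
#   )
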
class FastPitch(BaseTTS):
    """FastPitch model. Very similar to the SpeedySpeech model but with pitch prediction.

    Paper abstract:
        We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
        frequency contours. The model predicts pitch contours during inference. By altering these predictions,
        the generated speech can be more expressive, better match the semantic of the utterance, and in the end
        more engaging to the listener. Uniformly increasing or decreasing pitch with FastPitch generates speech
        that resembles the voluntary modulation of voice. Conditioning on frequency contours improves the overall
        quality of synthesized speech, making it comparable to state-of-the-art. It does not introduce an overhead,
        and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
        factor for mel-spectrogram synthesis of a typical utterance.

    Notes:
        TODO

    Args:
        config (Coqpit): Model coqpit class.

    Examples:
        >>> from TTS.tts.models.fast_pitch import FastPitch, FastPitchArgs
        >>> config = FastPitchArgs()
        >>> model = FastPitch(config)
    """

    def __init__(self, config: Coqpit):
        super().__init__()

        if "characters" in config:
            # loading from FastPitchConfig
            _, self.config, num_chars = self.get_characters(config)
            config.model_args.num_chars = num_chars
            args = self.config.model_args
        else:
            # loading from FastPitchArgs
            self.config = config
            args = config

        self.max_duration = args.max_duration
        self.use_gt_duration = args.use_gt_duration

        self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale

        self.encoder = FFTransformer(
            hidden_channels=args.hidden_channels,
            **args.encoder_params,
        )

        # if n_speakers > 1:
        #     self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
        # else:
        #     self.speaker_emb = None
        # self.speaker_emb_weight = speaker_emb_weight

        self.emb = nn.Embedding(args.num_chars, args.hidden_channels)

        self.duration_predictor = TemporalPredictor(
            args.hidden_channels,
            filter_size=args.duration_predictor_hidden_channels,
            kernel_size=args.duration_predictor_kernel_size,
            dropout=args.duration_predictor_dropout_p,
            num_layers=args.duration_predictor_num_layers,
        )

        self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params)

        self.pitch_predictor = TemporalPredictor(
            args.hidden_channels,
            filter_size=args.pitch_predictor_hidden_channels,
            kernel_size=args.pitch_predictor_kernel_size,
            dropout=args.pitch_predictor_dropout_p,
            num_layers=args.pitch_predictor_num_layers,
        )

        self.pitch_emb = nn.Conv1d(
            1,
            args.hidden_channels,
            kernel_size=args.pitch_embedding_kernel_size,
            padding=int((args.pitch_embedding_kernel_size - 1) / 2),
        )

        self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)

    @staticmethod
    def expand_encoder_outputs(en, dr, x_mask, y_mask):
        """Generate an attention alignment map from durations and expand the encoder outputs.

        Example:
            encoder output: [a,b,c,d]
            durations: [1, 3, 2, 1]

            expanded: [a, b, b, b, c, c, d]
            attention map: [[0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 1, 1, 0],
                            [0, 1, 1, 1, 0, 0, 0],
                            [1, 0, 0, 0, 0, 0, 0]]
        """
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
        o_en_ex = torch.matmul(attn.transpose(1, 2), en)
        return o_en_ex, attn.transpose(1, 2)
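
    # Shape reference for `forward` (inferred from how the tensors are used in
    # the method body, not from upstream documentation):
    #   x:         [B, T_text] token ids
    #   x_lengths: [B] number of valid tokens per item
    #   y_lengths: [B] number of target mel frames per item
    #   dr:        [B, T_text] ground-truth durations, in frames per token
    #   pitch:     [B, 1, T_mel] frame-level pitch values, averaged per token
    #              by `average_pitch` before being embedded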
    def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}):
        speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

        # Calculate speaker embedding
        # if self.speaker_emb is None:
        #     speaker_embedding = 0
        # else:
        #     speaker_embedding = self.speaker_emb(speaker).unsqueeze(1)
        #     speaker_embedding.mul_(self.speaker_emb_weight)

        # character embedding
        embedding = self.emb(x)

        # Input FFT
        o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)

        # Embedded for predictors
        o_en_dr, mask_en_dr = o_en, mask_en

        # Predict durations
        o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)

        # TODO: move this to the dataset
        avg_pitch = average_pitch(pitch, dr)

        # Predict pitch
        o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
        pitch_emb = self.pitch_emb(avg_pitch)
        o_en = o_en + pitch_emb.transpose(1, 2)

        # len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len)
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)

        # Output FFT
        o_de, _ = self.decoder(o_en_ex, y_lengths)
        o_de = self.proj(o_de)

        outputs = {
            "model_outputs": o_de,
            "durations_log": o_dr_log.squeeze(1),
            "durations": o_dr.squeeze(1),
            "pitch": o_pitch,
            "pitch_gt": avg_pitch,
            "alignments": attn,
        }
        return outputs

    @torch.no_grad()
    def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}):  # pylint: disable=unused-argument
        speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0

        # the input sequence should be greater than the max convolution size
        inference_padding = 5
        if x.shape[1] < 13:
            inference_padding += 13 - x.shape[1]

        # pad input to prevent dropping the last word
        x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)

        # character embedding
        embedding = self.emb(x)

        # if self.speaker_emb is None:
        # else:
        #     speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
        #     spk_emb = self.speaker_emb(speaker).unsqueeze(1)
        #     spk_emb.mul_(self.speaker_emb_weight)

        # Input FFT
        o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)

        # Predict durations
        o_dr_log = self.duration_predictor(o_en, mask_en)
        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
        o_dr = o_dr * self.length_scale

        # Pitch over chars
        o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)

        # if pitch_transform is not None:
        #     if self.pitch_std[0] == 0.0:
        #         # XXX LJSpeech-1.1 defaults
        #         mean, std = 218.14, 67.24
        #     else:
        #         mean, std = self.pitch_mean[0], self.pitch_std[0]
        #     pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std)

        o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2)
        o_en = o_en + o_pitch_emb

        y_lengths = o_dr.sum(1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
        o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)

        o_de, _ = self.decoder(o_en_ex, y_lengths)
        o_de = self.proj(o_de)

        outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log}
        return outputs
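
    # Illustrative `inference` call (a sketch; token ids are assumed to come
    # from the text preprocessing configured for the model, which is not shown
    # here):
    #   model.eval()
    #   token_ids = torch.LongTensor([[12, 7, 45, 3]])    # [1, T_text]
    #   out = model.inference(token_ids)
    #   mel = out["model_outputs"]                        # [1, T_mel, out_channels]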
    def train_step(self, batch: dict, criterion: nn.Module):
        text_input = batch["text_input"]
        text_lengths = batch["text_lengths"]
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        pitch = batch["pitch"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        durations = batch["durations"]

        aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
        outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input)

        # compute loss
        loss_dict = criterion(
            outputs["model_outputs"],
            mel_input,
            mel_lengths,
            outputs["durations_log"],
            durations,
            outputs["pitch"],
            outputs["pitch_gt"],
            text_lengths,
        )

        # compute duration error
        durations_pred = outputs["durations"]
        duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
        loss_dict["duration_error"] = duration_error

        return outputs, loss_dict

    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
        model_outputs = outputs["model_outputs"]
        alignments = outputs["alignments"]
        mel_input = batch["mel_input"]

        pred_spec = model_outputs[0].data.cpu().numpy()
        gt_spec = mel_input[0].data.cpu().numpy()
        align_img = alignments[0].data.cpu().numpy()

        figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        return figures, {"audio": train_audio}

    def eval_step(self, batch: dict, criterion: nn.Module):
        return self.train_step(batch, criterion)

    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
        return self.train_log(ap, batch, outputs)

    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        if eval:
            self.eval()
            assert not self.training

    def get_criterion(self):
        from TTS.tts.layers.losses import FastPitchLoss  # pylint: disable=import-outside-toplevel

        return FastPitchLoss(self.config)


def average_pitch(pitch, durs):
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
    pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
    pitch_cums = torch.nn.functional.pad(torch.cumsum(pitch, dim=2), (1, 0))

    bs, l = durs_cums_ends.size()
    n_formants = pitch.size(1)
    dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
    dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)

    pitch_sums = (torch.gather(pitch_cums, 2, dce) - torch.gather(pitch_cums, 2, dcs)).float()
    pitch_nelems = (torch.gather(pitch_nonzero_cums, 2, dce) - torch.gather(pitch_nonzero_cums, 2, dcs)).float()

    pitch_avg = torch.where(pitch_nelems == 0.0, pitch_nelems, pitch_sums / pitch_nelems)
    return pitch_avg
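

if __name__ == "__main__":
    # Minimal shape smoke test: an illustrative sketch, not part of the Coqui
    # training pipeline. Batch size, lengths, durations and pitch values below
    # are made up; real training goes through `train_step` with batches
    # produced by the TTS data loader.
    args = FastPitchArgs(num_chars=50, out_channels=80)
    model = FastPitch(args)

    x = torch.randint(0, 50, (2, 21))                  # token ids [B, T_text]
    x_lengths = torch.tensor([21, 17])
    dr = torch.randint(1, 4, (2, 21))                  # per-token durations in frames
    dr[1, 17:] = 0                                     # zero duration on padded tokens
    y_lengths = dr.sum(1)
    pitch = torch.randn(2, 1, int(y_lengths.max()))    # frame-level pitch [B, 1, T_mel]

    out = model.forward(x, x_lengths, y_lengths, dr, pitch)
    print(out["model_outputs"].shape)                  # expected: [2, T_mel, 80]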