diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py index 873f298e..3840d1d6 100644 --- a/TTS/tts/configs/fast_pitch_config.py +++ b/TTS/tts/configs/fast_pitch_config.py @@ -2,12 +2,12 @@ from dataclasses import dataclass, field from typing import List from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.models.fast_pitch import FastPitchArgs +from TTS.tts.models.forward_tts import ForwardTTSArgs @dataclass class FastPitchConfig(BaseTTSConfig): - """Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models. + """Configure `ForwardTTS` as FastPitch model. Example: @@ -36,22 +36,43 @@ class FastPitchConfig(BaseTTSConfig): d_vector_file (str): Path to the file including pre-computed speaker embeddings. Defaults to None. - noam_schedule (bool): - enable / disable the use of Noam LR scheduler. Defaults to False. + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. - warmup_steps (int): - Number of warm-up steps for the Noam scheduler. Defaults 4000. + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. lr (float): Initial learning rate. Defaults to `1e-3`. + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + wd (float): Weight decay coefficient. Defaults to `1e-7`. ssim_loss_alpha (float): Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. - huber_loss_alpha (float): + dur_loss_alpha (float): Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. spec_loss_alpha (float): @@ -73,9 +94,9 @@ class FastPitchConfig(BaseTTSConfig): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. """ - model: str = "fast_pitch" + model: str = "forward_tts" # model specific params - model_args: FastPitchArgs = field(default_factory=FastPitchArgs) + model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) # multi-speaker settings use_speaker_embedding: bool = False @@ -92,11 +113,13 @@ class FastPitchConfig(BaseTTSConfig): grad_clip: float = 5.0 # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True ssim_loss_alpha: float = 1.0 dur_loss_alpha: float = 1.0 spec_loss_alpha: float = 1.0 pitch_loss_alpha: float = 1.0 - dur_loss_alpha: float = 1.0 aligner_loss_alpha: float = 1.0 binary_align_loss_alpha: float = 1.0 binary_align_loss_start_step: int = 20000 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index b2641ab5..bdfc2a6b 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -2,81 +2,154 @@ from dataclasses import dataclass, field from typing import List from TTS.tts.configs.shared_configs import BaseTTSConfig -from TTS.tts.models.speedy_speech import SpeedySpeechArgs +from TTS.tts.models.forward_tts import ForwardTTSArgs @dataclass class SpeedySpeechConfig(BaseTTSConfig): - """Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models. + """Configure `ForwardTTS` as SpeedySpeech model. Example: >>> from TTS.tts.configs import SpeedySpeechConfig >>> config = SpeedySpeechConfig() - Args: + Args: model (str): - Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + model_args (Coqpit): - Model class arguments. Check `SpeedySpeechArgs` for more details. Defaults to `SpeedySpeechArgs()`. + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + data_dep_init_steps (int): Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses Activation Normalization that pre-computes normalization stats at the beginning and use the same values for the rest. Defaults to 10. + use_speaker_embedding (bool): enable / disable using speaker embeddings for multi-speaker models. If set True, the model is in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): Path to the file including pre-computed speaker embeddings. Defaults to None. - noam_schedule (bool): - enable / disable the use of Noam LR scheduler. Defaults to False. - warmup_steps (int): - Number of warm-up steps for the Noam scheduler. Defaults 4000. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `RAdam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + lr (float): Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + wd (float): Weight decay coefficient. Defaults to `1e-7`. - ssim_alpha (float): - Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0. - huber_alpha (float): - Weight for the duration predictor's loss. Defaults to 1.0. - l1_alpha (float): - Weight for the L1 spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_align_loss_start_step (int): + Start binary alignment loss after this many steps. Defaults to 20000. + min_seq_len (int): Minimum input sequence length to be used at training. + max_seq_len (int): Maximum input sequence length to be used at training. Larger values result in more VRAM usage. """ - model: str = "speedy_speech" - # model specific params - model_args: SpeedySpeechArgs = field(default_factory=SpeedySpeechArgs) + model: str = "forward_tts" + + # set model args as SpeedySpeech + model_args: ForwardTTSArgs = ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + num_speakers=0, + positional_encoding=True, + ) # multi-speaker settings use_speaker_embedding: bool = False use_d_vector_file: bool = False d_vector_file: str = False + d_vector_dim: int = 0 # optimizer parameters optimizer: str = "RAdam" optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) - lr_scheduler: str = None - lr_scheduler_params: dict = None + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) lr: float = 1e-4 grad_clip: float = 5.0 # loss params - ssim_alpha: float = 1.0 - huber_alpha: float = 1.0 - l1_alpha: float = 1.0 + spec_loss_type: str = "l1" + duration_loss_type: str = "huber" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_align_loss_start_step: int = 20000 # overrides min_seq_len: int = 13 max_seq_len: int = 200 r: int = 1 # DO NOT CHANGE + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + # testing test_sentences: List[str] = field( default_factory=lambda: [ diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py new file mode 100644 index 00000000..c4411027 --- /dev/null +++ b/TTS/tts/models/forward_tts.py @@ -0,0 +1,695 @@ +from dataclasses import dataclass, field +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram +from TTS.utils.audio import AudioProcessor + + +@dataclass +class ForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + use_d_vetor (bool): + Whether to use precomputed d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of channels of the d-vectors. Defaults to 0. + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + num_speakers: int = 0 + use_aligner: bool = True + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + use_d_vector: bool = False + d_vector_dim: int = 0 + detach_duration_predictor: bool = False + max_duration: int = 75 + + +class ForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - TODO: FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. + + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + + super().__init__() + + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + if "characters" in config: + # loading from FasrPitchConfig + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + # loading from ForwardTTSArgs + self.config = config + self.args = config.model_args + elif isinstance(config, ForwardTTSArgs): + self.args = config + self.config = config + else: + raise ValueError("config must be either a *Config or ForwardTTSArgs") + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.use_binary_alignment_loss = False + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.args.d_vector_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels + self.args.d_vector_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels + self.args.d_vector_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.num_speakers > 1 and not self.args.use_d_vector: + # speaker embedding layer + self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples: + - encoder output: :math:`[a,b,c,d]` + - durations: :math:`[1, 3, 2, 1]` + + - expanded: :math:`[a, b, b, b, c, c, d]` + - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]` + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + return x + g + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Concat speaker embedding to the encoder output for the duration predictor. + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_over_durations(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) + o_en = o_en + o_pitch_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] if self.args.use_pitch else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch = batch["pitch"] + pitch_avg_expanded, _ = self.expand_encoder_outputs( + outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"] + ) + pitch = pitch[0, 0].data.cpu().numpy() + # TODO: denormalize before plotting + pitch = abs(pitch) + pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy() + pitch_figures = { + "pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False), + "pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False), + } + figures.update(pitch_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): + return self.train_log(ap, batch, outputs) + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Enable binary alignment loss when needed""" + if trainer.total_steps_done > self.config.binary_align_loss_start_step: + self.use_binary_alignment_loss = True diff --git a/tests/tts_tests/test_forward_tts.py b/tests/tts_tests/test_forward_tts.py new file mode 100644 index 00000000..9bb60f48 --- /dev/null +++ b/tests/tts_tests/test_forward_tts.py @@ -0,0 +1,149 @@ +import unittest + +import torch as T + +from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs +from TTS.tts.utils.helpers import sequence_mask + +# pylint: disable=unused-variable + + +def expand_encoder_outputs_test(): + model = ForwardTTS(ForwardTTSArgs(num_chars=10)) + + inputs = T.rand(2, 5, 57) + durations = T.randint(1, 4, (2, 57)) + + x_mask = T.ones(2, 1, 57) + y_mask = T.ones(2, 1, durations.sum(1).max()) + + expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) + + for b in range(durations.shape[0]): + index = 0 + for idx, dur in enumerate(durations[b]): + diff = ( + expanded[b, :, index : index + dur.item()] + - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape) + ).sum() + assert abs(diff) < 1e-6, diff + index += dur + + +def model_input_output_test(): + """Assert the output shapes of the model in different modes""" + + # VANILLA MODEL + model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=False)) + + x = T.randint(0, 10, (2, 21)) + x_lengths = T.randint(10, 22, (2,)) + x_lengths[-1] = 21 + x_mask = sequence_mask(x_lengths).unsqueeze(1).long() + durations = T.randint(1, 4, (2, 21)) + durations = durations * x_mask.squeeze(1) + y_lengths = durations.sum(1) + y_mask = sequence_mask(y_lengths).unsqueeze(1).long() + + outputs = model.forward(x, x_lengths, y_lengths, dr=durations) + + assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80) + assert outputs["durations_log"].shape == (2, 21) + assert outputs["durations"].shape == (2, 21) + assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21) + assert (outputs["x_mask"] - x_mask).sum() == 0.0 + assert (outputs["y_mask"] - y_mask).sum() == 0.0 + + assert outputs["alignment_soft"] == None + assert outputs["alignment_mas"] == None + assert outputs["alignment_logprob"] == None + assert outputs["o_alignment_dur"] == None + assert outputs["pitch_avg"] == None + assert outputs["pitch_avg_gt"] == None + + # USE PITCH + model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=False)) + + x = T.randint(0, 10, (2, 21)) + x_lengths = T.randint(10, 22, (2,)) + x_lengths[-1] = 21 + x_mask = sequence_mask(x_lengths).unsqueeze(1).long() + durations = T.randint(1, 4, (2, 21)) + durations = durations * x_mask.squeeze(1) + y_lengths = durations.sum(1) + y_mask = sequence_mask(y_lengths).unsqueeze(1).long() + pitch = T.rand(2, 1, y_lengths.max()) + + outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch) + + assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80) + assert outputs["durations_log"].shape == (2, 21) + assert outputs["durations"].shape == (2, 21) + assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21) + assert (outputs["x_mask"] - x_mask).sum() == 0.0 + assert (outputs["y_mask"] - y_mask).sum() == 0.0 + assert outputs["pitch_avg"].shape == (2, 1, 21) + assert outputs["pitch_avg_gt"].shape == (2, 1, 21) + + assert outputs["alignment_soft"] == None + assert outputs["alignment_mas"] == None + assert outputs["alignment_logprob"] == None + assert outputs["o_alignment_dur"] == None + + # USE ALIGNER NETWORK + model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=True)) + + x = T.randint(0, 10, (2, 21)) + x_lengths = T.randint(10, 22, (2,)) + x_lengths[-1] = 21 + x_mask = sequence_mask(x_lengths).unsqueeze(1).long() + durations = T.randint(1, 4, (2, 21)) + durations = durations * x_mask.squeeze(1) + y_lengths = durations.sum(1) + y_mask = sequence_mask(y_lengths).unsqueeze(1).long() + y = T.rand(2, y_lengths.max(), 80) + + outputs = model.forward(x, x_lengths, y_lengths, dr=durations, y=y) + + assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80) + assert outputs["durations_log"].shape == (2, 21) + assert outputs["durations"].shape == (2, 21) + assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21) + assert (outputs["x_mask"] - x_mask).sum() == 0.0 + assert (outputs["y_mask"] - y_mask).sum() == 0.0 + assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21) + assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21) + assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21) + assert outputs["o_alignment_dur"].shape == (2, 21) + + assert outputs["pitch_avg"] == None + assert outputs["pitch_avg_gt"] == None + + # USE ALIGNER NETWORK AND PITCH + model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=True)) + + x = T.randint(0, 10, (2, 21)) + x_lengths = T.randint(10, 22, (2,)) + x_lengths[-1] = 21 + x_mask = sequence_mask(x_lengths).unsqueeze(1).long() + durations = T.randint(1, 4, (2, 21)) + durations = durations * x_mask.squeeze(1) + y_lengths = durations.sum(1) + y_mask = sequence_mask(y_lengths).unsqueeze(1).long() + y = T.rand(2, y_lengths.max(), 80) + pitch = T.rand(2, 1, y_lengths.max()) + + outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch, y=y) + + assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80) + assert outputs["durations_log"].shape == (2, 21) + assert outputs["durations"].shape == (2, 21) + assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21) + assert (outputs["x_mask"] - x_mask).sum() == 0.0 + assert (outputs["y_mask"] - y_mask).sum() == 0.0 + assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21) + assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21) + assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21) + assert outputs["o_alignment_dur"].shape == (2, 21) + assert outputs["pitch_avg"].shape == (2, 1, 21) + assert outputs["pitch_avg_gt"].shape == (2, 1, 21)