From c3fb49bf7642031789bfcd28df3ecc5f5c5f3825 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 19 Apr 2022 09:21:31 +0200
Subject: [PATCH] Refactor ForwardTTS to allow skipping the decoder

---
 TTS/tts/models/forward_tts.py | 90 +++++++++++++++++++++++++----------
 1 file changed, 65 insertions(+), 25 deletions(-)

diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py
index 46e1812d..e5c275c7 100644
--- a/TTS/tts/models/forward_tts.py
+++ b/TTS/tts/models/forward_tts.py
@@ -126,11 +126,23 @@ class ForwardTTSArgs(Coqpit):
     length_scale: int = 1
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     decoder_type: str = "fftransformer"
     decoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     detach_duration_predictor: bool = False
     max_duration: int = 75
@@ -396,31 +408,24 @@ class ForwardTTS(BaseTTS):
         # speaker conditioning
         return x_emb, x_mask, g, o_en
 
-    def _forward_decoder(
-        self,
-        o_en: torch.FloatTensor,
-        dr: torch.IntTensor,
-        x_mask: torch.FloatTensor,
-        y_lengths: torch.IntTensor,
-        g: torch.FloatTensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        """Decoding forward pass.
+    def _expand_encoder(
+        self, o_en: torch.FloatTensor, y_lengths: torch.IntTensor, dr: torch.IntTensor, x_mask: torch.FloatTensor
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Expand encoder outputs to match the decoder's time resolution.
 
         1. Compute the decoder output mask
         2. Expand encoder output with the durations.
         3. Apply position encoding.
-        4. Add speaker embeddings if multi-speaker mode.
-        5. Run the decoder.
 
         Args:
             o_en (torch.FloatTensor): Encoder output.
+            y_lengths (torch.IntTensor): Output sequence lengths.
             dr (torch.IntTensor): Ground truth durations or alignment network durations.
             x_mask (torch.IntTensor): Input sequence mask.
-            y_lengths (torch.IntTensor): Output sequence lengths.
-            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
 
         Returns:
-            Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
+            Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: Decoder mask, expanded encoder outputs,
+                attention map from durations.
         """
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
@@ -428,9 +433,30 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
+        return y_mask, o_en_ex, attn.transpose(1, 2)
+
+    def _forward_decoder(
+        self,
+        o_en_ex: torch.FloatTensor,
+        y_mask: torch.FloatTensor,
+        g: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """Decoding forward pass.
+
+        1. Run the decoder network.
+
+        Args:
+            o_en_ex (torch.FloatTensor): Expanded encoder output.
+            y_mask (torch.FloatTensor): Output sequence mask.
+            g (torch.FloatTensor): Conditioning vectors. Typically speaker embeddings.
+
+        Returns:
+            torch.FloatTensor: Decoder output.
+        """
+        # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
-        return o_de.transpose(1, 2), attn.transpose(1, 2)
+        return o_de.transpose(1, 2)
 
     def _forward_pitch_predictor(
         self,
         o_en: torch.FloatTensor,
@@ -556,13 +582,15 @@ class ForwardTTS(BaseTTS):
         """
         g = self._set_speaker_input(aux_input)
         # compute sequence masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2]
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
+        x_emb, x_mask, g, o_en = self._forward_encoder(
+            x, x_mask, g
+        )  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max]
+            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
         else:
             o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
@@ -589,8 +617,12 @@ class ForwardTTS(BaseTTS):
                 o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
             )
             o_en = o_en + o_pitch_emb
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(
+            o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
+        )  # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)  # [B, T_max2, C_de], [B, T_max2, T_max]
+        o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)  # [B, T_max2, C_de]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "g": g,  # [B, C]
@@ -606,16 +638,20 @@ class ForwardTTS(BaseTTS):
             "aligner_logprob": aligner_logprob,
             "x_mask": x_mask,
             "y_mask": y_mask,
+            "o_en_ex": o_en_ex,
         }
         return outputs
 
     @torch.no_grad()
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
+    def inference(
+        self, x, aux_input={"d_vectors": None, "speaker_ids": None}, skip_decoder=False
+    ):  # pylint: disable=unused-argument
         """Model's inference pass.
 
         Args:
             x (torch.LongTensor): Input character sequence.
             aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
+            skip_decoder (bool): Skip the decoder and return the expanded encoder outputs instead. Defaults to False.
 
         Shapes:
             - x: [B, T_max]
@@ -636,15 +672,19 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
-        # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
         outputs = {
-            "model_outputs": o_de,
             "alignments": attn,
             "pitch": o_pitch,
            "durations_log": o_dr_log,
             "g": g,
         }
+        if skip_decoder:
+            outputs["o_en_ex"] = o_en_ex
+        else:
+            # decoder pass
+            outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
        return outputs
 
     def train_step(self, batch: dict, criterion: nn.Module):