Refactor ForwardTTS to skip decoder

Eren Gölge 2022-04-19 09:21:31 +02:00 committed by Eren Gölge
parent cc57c20162
commit c3fb49bf76
1 changed file with 65 additions and 25 deletions


@@ -126,11 +126,23 @@ class ForwardTTSArgs(Coqpit):
     length_scale: int = 1
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     decoder_type: str = "fftransformer"
     decoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     detach_duration_predictor: bool = False
     max_duration: int = 75
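For context, `encoder_params` and `decoder_params` are plain dict fields on a Coqpit dataclass, so a config can override individual entries, including the `kernel_size_fft` key this commit adds to the defaults. A minimal sketch (the import path is assumed from the usual repo layout):

from TTS.tts.models.forward_tts import ForwardTTSArgs

# Override the default FFTransformer settings; fields not passed here
# keep the dataclass defaults shown in the hunk above.
args = ForwardTTSArgs(
    encoder_type="fftransformer",
    encoder_params={
        "hidden_channels_ffn": 1024,
        "num_heads": 2,  # more heads than the default 1, for illustration
        "num_layers": 6,
        "dropout_p": 0.1,
        "kernel_size_fft": 9,
    },
)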
@@ -396,31 +408,24 @@ class ForwardTTS(BaseTTS):
         # speaker conditioning
         return x_emb, x_mask, g, o_en
 
-    def _forward_decoder(
-        self,
-        o_en: torch.FloatTensor,
-        dr: torch.IntTensor,
-        x_mask: torch.FloatTensor,
-        y_lengths: torch.IntTensor,
-        g: torch.FloatTensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        """Decoding forward pass.
+    def _expand_encoder(
+        self, o_en: torch.FloatTensor, y_lengths: torch.IntTensor, dr: torch.IntTensor, x_mask: torch.FloatTensor
+    ):
+        """Expand encoder outputs to match the decoder.
 
         1. Compute the decoder output mask
         2. Expand encoder output with the durations.
         3. Apply position encoding.
-        4. Add speaker embeddings if multi-speaker mode.
-        5. Run the decoder.
 
         Args:
             o_en (torch.FloatTensor): Encoder output.
+            y_lengths (torch.IntTensor): Output sequence lengths.
             dr (torch.IntTensor): Ground truth durations or alignment network durations.
             x_mask (torch.IntTensor): Input sequence mask.
-            y_lengths (torch.IntTensor): Output sequence lengths.
-            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
 
         Returns:
-            Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
+            Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: Decoder mask, expanded encoder outputs,
+            attention map.
         """
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
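The "expand o_en with durations" step is the heart of any non-attentive TTS model: each encoder frame is repeated `dr[i]` times along the time axis so the encoder sequence reaches decoder length. A self-contained sketch of the idea using `torch.repeat_interleave` (illustrative only; the real code builds the equivalent expansion from a duration-derived attention map, which is why `attn.transpose(1, 2)` is returned in the next hunk):

import torch

def expand_with_durations(o_en: torch.Tensor, dr: torch.Tensor) -> torch.Tensor:
    """Repeat each encoder frame by its duration.

    o_en: [C, T_en] encoder output for a single utterance.
    dr:   [T_en] integer durations, one per input token.
    Returns [C, sum(dr)] expanded features.
    """
    # repeat_interleave repeats along dim=1; tokens with dr == 0 are dropped
    return torch.repeat_interleave(o_en, dr, dim=1)

o_en = torch.randn(4, 3)       # 3 input tokens, 4 channels
dr = torch.tensor([2, 0, 3])   # token 0 -> 2 frames, token 2 -> 3 frames
print(expand_with_durations(o_en, dr).shape)  # torch.Size([4, 5])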
@@ -428,9 +433,30 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
+        return y_mask, o_en_ex, attn.transpose(1, 2)
+
+    def _forward_decoder(
+        self,
+        o_en_ex: torch.FloatTensor,
+        y_mask: torch.FloatTensor,
+        g: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """Decoding forward pass.
+
+        1. Run the decoder network.
+
+        Args:
+            o_en_ex (torch.FloatTensor): Expanded encoder output.
+            y_mask (torch.FloatTensor): Output sequence mask.
+            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
+
+        Returns:
+            torch.FloatTensor: Decoder output.
+        """
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
-        return o_de.transpose(1, 2), attn.transpose(1, 2)
+        return o_de.transpose(1, 2)
 
     def _forward_pitch_predictor(
         self,
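Net effect of the split: the old single decoder call becomes two composable steps inside `forward` and `inference`. Schematically, with the tensor names used in the diff:

# before this commit, inside forward():
#   o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
# after this commit:
y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask)
o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)

Callers that only need the expanded encoder states can now stop after the first step.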
@@ -556,13 +582,15 @@
         """
         g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
+        x_emb, x_mask, g, o_en = self._forward_encoder(
+            x, x_mask, g
+        )  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
             o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
         else:
             o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
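The duration predictor works in log space: the last line maps predicted log-durations back to frame counts and caps them at `max_duration` (75 by default, per the args above). A quick numeric check of that mapping:

import torch

o_dr_log = torch.tensor([[0.0, 1.0, 2.5, 6.0]])  # predicted log(1 + duration)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, 75)
print(o_dr)  # tensor([[ 0.0000,  1.7183, 11.1825, 75.0000]])  <- 402.4 clamped to 75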
@@ -589,8 +617,12 @@
             o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
         )
         o_en = o_en + o_pitch_emb
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(
+            o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
+        )  # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)  # [B, T_max2, C_de], [B, T_max2, T_max]
+        o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)  # [B, T_max2, C_de]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "g": g,  # [B, C]
@@ -606,16 +638,20 @@
             "aligner_logprob": aligner_logprob,
             "x_mask": x_mask,
             "y_mask": y_mask,
+            "o_en_ex": o_en_ex,
         }
         return outputs
 
     @torch.no_grad()
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
+    def inference(
+        self, x, aux_input={"d_vectors": None, "speaker_ids": None}, skip_decoder=False
+    ):  # pylint: disable=unused-argument
         """Model's inference pass.
 
         Args:
             x (torch.LongTensor): Input character sequence.
             aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
+            skip_decoder (bool): Whether to skip the decoder. Defaults to False.
 
         Shapes:
             - x: [B, T_max]
@@ -636,15 +672,19 @@
         if self.args.use_pitch:
             o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
-        # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
         outputs = {
-            "model_outputs": o_de,
             "alignments": attn,
             "pitch": o_pitch,
             "durations_log": o_dr_log,
             "g": g,
         }
+        if skip_decoder:
+            outputs["o_en_ex"] = o_en_ex
+        else:
+            # decoder pass
+            outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
         return outputs
 
     def train_step(self, batch: dict, criterion: nn.Module):
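Taken together, the refactor lets a caller stop at the expanded encoder representation instead of always running the decoder. A usage sketch (model construction elided; `model` is assumed to be a built `ForwardTTS` instance and `x` a batch of token ids):

import torch

x = torch.randint(0, 100, (1, 42))  # [B, T_max] token ids (dummy values)

# Regular inference: the decoder runs and produces a spectrogram.
out = model.inference(x)
mel = out["model_outputs"]          # [B, T_max2, C_de]

# New path: skip the decoder and keep the expanded encoder states,
# e.g. to feed them into another network.
out = model.inference(x, skip_decoder=True)
hidden = out["o_en_ex"]             # [B, C_en, T_max2]
assert "model_outputs" not in out   # decoder never ran

Note that training-time `forward` now also surfaces `o_en_ex` in its outputs dict, so the same tensor is reachable in both modes.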