Refactor ForwardTTS to skip decoder

Eren Gölge 2022-04-19 09:21:31 +02:00 committed by Eren Gölge
parent cc57c20162
commit c3fb49bf76
1 changed file with 65 additions and 25 deletions


@@ -126,11 +126,23 @@ class ForwardTTSArgs(Coqpit):
length_scale: int = 1
encoder_type: str = "fftransformer"
encoder_params: dict = field(
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
default_factory=lambda: {
"hidden_channels_ffn": 1024,
"num_heads": 1,
"num_layers": 6,
"dropout_p": 0.1,
"kernel_size_fft": 9,
}
)
decoder_type: str = "fftransformer"
decoder_params: dict = field(
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
default_factory=lambda: {
"hidden_channels_ffn": 1024,
"num_heads": 1,
"num_layers": 6,
"dropout_p": 0.1,
"kernel_size_fft": 9,
}
)
detach_duration_predictor: bool = False
max_duration: int = 75
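
Both FFTransformer parameter sets now also carry a kernel_size_fft entry. A minimal sketch of overriding these defaults by hand; the import path assumes the Coqui TTS package layout (TTS.tts.models.forward_tts) and the values are illustrative:

# Sketch: overriding the FFTransformer encoder/decoder params, including the
# kernel_size_fft entry added in this commit. Import path assumed from the
# Coqui TTS layout; adjust to your install.
from TTS.tts.models.forward_tts import ForwardTTSArgs

args = ForwardTTSArgs(
    encoder_type="fftransformer",
    encoder_params={
        "hidden_channels_ffn": 1024,
        "num_heads": 1,
        "num_layers": 6,
        "dropout_p": 0.1,
        "kernel_size_fft": 9,
    },
    decoder_type="fftransformer",
    decoder_params={
        "hidden_channels_ffn": 1024,
        "num_heads": 1,
        "num_layers": 6,
        "dropout_p": 0.1,
        "kernel_size_fft": 9,
    },
)
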
@@ -396,31 +408,24 @@ class ForwardTTS(BaseTTS):
# speaker conditioning
return x_emb, x_mask, g, o_en
def _forward_decoder(
self,
o_en: torch.FloatTensor,
dr: torch.IntTensor,
x_mask: torch.FloatTensor,
y_lengths: torch.IntTensor,
g: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Decoding forward pass.
def _expand_encoder(
self, o_en: torch.FloatTensor, y_lengths: torch.IntTensor, dr: torch.IntTensor, x_mask: torch.FloatTensor
):
"""Expand encoder outputs to match the decoder.
1. Compute the decoder output mask
2. Expand encoder output with the durations.
3. Apply position encoding.
4. Add speaker embeddings if multi-speaker mode.
5. Run the decoder.
Args:
o_en (torch.FloatTensor): Encoder output.
y_lengths (torch.IntTensor): Output sequence lengths.
dr (torch.IntTensor): Ground truth durations or alignment network durations.
x_mask (torch.IntTensor): Input sequence mask.
y_lengths (torch.IntTensor): Output sequence lengths.
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: Decoder mask, expanded encoder outputs,
attention map
"""
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
# expand o_en with durations
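
For reference, the expansion step repeats each encoder frame according to its duration so the expanded sequence length matches the decoder (spectrogram) frames; ForwardTTS derives an attention map from the durations to do this in a batched, masked way, but the core idea can be sketched with torch.repeat_interleave (illustrative only, single unpadded item):

import torch

def expand_by_durations(o_en: torch.Tensor, dr: torch.LongTensor) -> torch.Tensor:
    # o_en: [1, C, T_en] encoder output, dr: [1, T_en] integer durations.
    # Returns [1, C, T_de] with T_de = dr.sum(); a stand-in for the
    # attention-based expansion used inside the model.
    return torch.repeat_interleave(o_en, dr[0], dim=2)

o_en = torch.randn(1, 4, 3)     # 3 input tokens, 4 channels
dr = torch.tensor([[2, 1, 3]])  # token 0 -> 2 frames, token 1 -> 1, token 2 -> 3
print(expand_by_durations(o_en, dr).shape)  # torch.Size([1, 4, 6])
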
@@ -428,9 +433,30 @@ class ForwardTTS(BaseTTS):
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
return y_mask, o_en_ex, attn.transpose(1, 2)
def _forward_decoder(
self,
o_en_ex: torch.FloatTensor,
y_mask: torch.FloatTensor,
g: torch.FloatTensor,
) -> torch.FloatTensor:
"""Decoding forward pass.
1. Run the decoder network.
Args:
o_en_ex (torch.FloatTensor): Expanded encoder output.
y_mask (torch.FloatTensor): Output sequence mask.
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
Returns:
torch.FloatTensor: Decoder output.
"""
# decoder pass
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de.transpose(1, 2), attn.transpose(1, 2)
return o_de.transpose(1, 2)
def _forward_pitch_predictor(
self,
@@ -556,13 +582,15 @@ class ForwardTTS(BaseTTS):
"""
g = self._set_speaker_input(aux_input)
# compute sequence masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2]
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() # [B, 1, T_max2]
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() # [B, 1, T_max]
# encoder pass
x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
x_emb, x_mask, g, o_en = self._forward_encoder(
x, x_mask, g
) # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
# duration predictor pass
if self.args.detach_duration_predictor:
o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max]
o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g) # [B, 1, T_max]
else:
o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g) # [B, 1, T_max]
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
@@ -589,8 +617,12 @@ class ForwardTTS(BaseTTS):
o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
)
o_en = o_en + o_pitch_emb
# expand encoder outputs
y_mask, o_en_ex, attn = self._expand_encoder(
o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
) # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
# decoder pass
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) # [B, T_max2, C_de], [B, T_max2, T_max]
o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g) # [B, T_max2, C_de]
outputs = {
"model_outputs": o_de, # [B, T, C]
"g": g, # [B, C]
@@ -606,16 +638,20 @@ class ForwardTTS(BaseTTS):
"aligner_logprob": aligner_logprob,
"x_mask": x_mask,
"y_mask": y_mask,
"o_en_ex": o_en_ex,
}
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
def inference(
self, x, aux_input={"d_vectors": None, "speaker_ids": None}, skip_decoder=False
): # pylint: disable=unused-argument
"""Model's inference pass.
Args:
x (torch.LongTensor): Input character sequence.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
skip_decoder (bool): Whether to skip the decoder. Defaults to False.
Shapes:
- x: [B, T_max]
@@ -636,15 +672,19 @@ class ForwardTTS(BaseTTS):
if self.args.use_pitch:
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
o_en = o_en + o_pitch_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
# expand encoder outputs
y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
outputs = {
"model_outputs": o_de,
"alignments": attn,
"pitch": o_pitch,
"durations_log": o_dr_log,
"g": g,
}
if skip_decoder:
outputs["o_en_ex"] = o_en_ex
else:
# decoder pass
outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
return outputs
def train_step(self, batch: dict, criterion: nn.Module):
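
Taken together, splitting _expand_encoder out of _forward_decoder lets a caller stop after the duration-based expansion and run the decoder separately (or replace it). A hedged usage sketch; model (an initialized ForwardTTS instance), x (a [1, T_max] batch of token ids), and the all-ones mask shortcut for a single unpadded item are assumptions for illustration, not part of this commit:

import torch

with torch.no_grad():
    outputs = model.inference(x, skip_decoder=True)
    o_en_ex = outputs["o_en_ex"]  # [B, C_en, T_de] expanded encoder output
    attn = outputs["alignments"]  # [B, T_de, T_max] duration-based attention

    # Rebuild the decoder mask from the expanded length; with one unpadded
    # item this is all ones, mirroring what _expand_encoder derives from y_lengths.
    y_mask = torch.ones(o_en_ex.shape[0], 1, o_en_ex.shape[2], dtype=o_en_ex.dtype)

    # Run the decoder later, on the expanded representation.
    o_de = model._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=outputs["g"])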