mirror of https://github.com/coqui-ai/TTS.git
Refactor ForwardTTS to skip decoder
parent cc57c20162
commit c3fb49bf76
@@ -126,11 +126,23 @@ class ForwardTTSArgs(Coqpit):
     length_scale: int = 1
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     decoder_type: str = "fftransformer"
     decoder_params: dict = field(
-        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
+        default_factory=lambda: {
+            "hidden_channels_ffn": 1024,
+            "num_heads": 1,
+            "num_layers": 6,
+            "dropout_p": 0.1,
+            "kernel_size_fft": 9,
+        }
     )
     detach_duration_predictor: bool = False
     max_duration: int = 75
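Both encoder_params and decoder_params defaults gain a "kernel_size_fft" entry (presumably the kernel size of the convolutions inside the FFT blocks) while being reflowed to one key per line. A minimal sketch of overriding the new default, assuming the diff targets TTS/tts/models/forward_tts.py as the class names suggest:

    from TTS.tts.models.forward_tts import ForwardTTSArgs

    # Override only the encoder side; the decoder keeps its defaults.
    args = ForwardTTSArgs(
        encoder_params={
            "hidden_channels_ffn": 1024,
            "num_heads": 1,
            "num_layers": 6,
            "dropout_p": 0.1,
            "kernel_size_fft": 9,  # key added by this commit
        }
    )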
@@ -396,31 +408,24 @@ class ForwardTTS(BaseTTS):
         # speaker conditioning
         return x_emb, x_mask, g, o_en

-    def _forward_decoder(
-        self,
-        o_en: torch.FloatTensor,
-        dr: torch.IntTensor,
-        x_mask: torch.FloatTensor,
-        y_lengths: torch.IntTensor,
-        g: torch.FloatTensor,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        """Decoding forward pass.
+    def _expand_encoder(
+        self, o_en: torch.FloatTensor, y_lengths: torch.IntTensor, dr: torch.IntTensor, x_mask: torch.FloatTensor
+    ):
+        """Expand encoder outputs to match the decoder.

         1. Compute the decoder output mask
         2. Expand encoder output with the durations.
         3. Apply position encoding.
-        4. Add speaker embeddings if multi-speaker mode.
-        5. Run the decoder.

         Args:
             o_en (torch.FloatTensor): Encoder output.
+            y_lengths (torch.IntTensor): Output sequence lengths.
             dr (torch.IntTensor): Ground truth durations or alignment network durations.
             x_mask (torch.IntTensor): Input sequence mask.
-            y_lengths (torch.IntTensor): Output sequence lengths.
-            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.

         Returns:
-            Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
+            Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: Decoder mask, expanded encoder outputs,
+            attention map.
         """
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
@@ -428,9 +433,30 @@ class ForwardTTS(BaseTTS):
         # positional encoding
         if hasattr(self, "pos_encoder"):
             o_en_ex = self.pos_encoder(o_en_ex, y_mask)
+        return y_mask, o_en_ex, attn.transpose(1, 2)
+
+    def _forward_decoder(
+        self,
+        o_en_ex: torch.FloatTensor,
+        y_mask: torch.FloatTensor,
+        g: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        """Decoding forward pass.
+
+        1. Run the decoder network.
+
+        Args:
+            o_en_ex (torch.FloatTensor): Expanded encoder output.
+            y_mask (torch.FloatTensor): Decoder output mask.
+            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
+
+        Returns:
+            torch.FloatTensor: Decoder output.
+        """
+
         # decoder pass
         o_de = self.decoder(o_en_ex, y_mask, g=g)
-        return o_de.transpose(1, 2), attn.transpose(1, 2)
+        return o_de.transpose(1, 2)

     def _forward_pitch_predictor(
         self,
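Splitting _expand_encoder out of the old _forward_decoder makes the duration-based expansion reusable without running the decoder. Conceptually, expanding encoder output with integer durations repeats each encoder frame along the time axis; a toy, self-contained illustration of that idea (not the repo's mask-and-attention implementation):

    import torch

    o_en = torch.randn(1, 4, 3)   # [B, C_en, T_max]: three encoder frames
    dr = torch.tensor([2, 1, 3])  # integer duration per input token
    # Repeat frame i dr[i] times along time -> [B, C_en, sum(dr)]
    o_en_ex = torch.repeat_interleave(o_en, dr, dim=2)
    assert o_en_ex.shape == (1, 4, 6)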
@@ -556,13 +582,15 @@ class ForwardTTS(BaseTTS):
         """
         g = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        x_emb, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
+        x_emb, x_mask, g, o_en = self._forward_encoder(
+            x, x_mask, g
+        )  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
             o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
         else:
             o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
@@ -589,8 +617,12 @@ class ForwardTTS(BaseTTS):
                 o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
             )
             o_en = o_en + o_pitch_emb
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(
+            o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
+        )  # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
         # decoder pass
-        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)  # [B, T_max2, C_de], [B, T_max2, T_max]
+        o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)  # [B, T_max2, C_de]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
             "g": g,  # [B, C]
@@ -606,16 +638,20 @@ class ForwardTTS(BaseTTS):
             "aligner_logprob": aligner_logprob,
             "x_mask": x_mask,
             "y_mask": y_mask,
+            "o_en_ex": o_en_ex,
         }
         return outputs

     @torch.no_grad()
-    def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
+    def inference(
+        self, x, aux_input={"d_vectors": None, "speaker_ids": None}, skip_decoder=False
+    ):  # pylint: disable=unused-argument
         """Model's inference pass.

         Args:
             x (torch.LongTensor): Input character sequence.
             aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
+            skip_decoder (bool): Whether to skip the decoder. Defaults to False.

         Shapes:
             - x: [B, T_max]
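Besides the new inference signature, the hunk above also adds "o_en_ex" to the outputs of forward(), so training-side code can read the expanded encoder states next to the existing masks. A hedged sketch of consuming those keys; `model` is assumed to be an initialized ForwardTTS and the argument wiring is assumed from context:

    outputs = model.forward(x, x_lengths, y_lengths, dr=dr, aux_input=aux_input)
    o_en_ex = outputs["o_en_ex"]  # [B, C_en, T_max2], aligned with outputs["y_mask"]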
@@ -636,15 +672,19 @@ class ForwardTTS(BaseTTS):
         if self.args.use_pitch:
             o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
             o_en = o_en + o_pitch_emb
-        # decoder pass
-        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
+        # expand encoder outputs
+        y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
         outputs = {
-            "model_outputs": o_de,
             "alignments": attn,
             "pitch": o_pitch,
             "durations_log": o_dr_log,
             "g": g,
         }
+        if skip_decoder:
+            outputs["o_en_ex"] = o_en_ex
+        else:
+            # decoder pass
+            outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
         return outputs

     def train_step(self, batch: dict, criterion: nn.Module):
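A usage sketch of the new flag; `model` (an initialized ForwardTTS) and `token_ids` (a [B, T_max] LongTensor) are assumed, not defined by this diff:

    # Stop after expansion: the decoder never runs and no "model_outputs" key is set.
    outputs = model.inference(token_ids, skip_decoder=True)
    o_en_ex = outputs["o_en_ex"]  # [B, C_en, T_max2] expanded encoder states

    # Default behavior is unchanged: the decoder runs.
    outputs = model.inference(token_ids)
    mel = outputs["model_outputs"]  # [B, T_max2, C_de]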