mirror of https://github.com/coqui-ai/TTS.git
Rename `g` as `spk_emb`

commit 8adcd1de8e (parent 2d29e8219d)
@@ -170,6 +170,11 @@ class ForwardTTS(BaseTTS):
     If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each
     input character as in the FastPitch model.
 
+    ::
+
+                                                         |-----> (optional) PitchPredictor(o_en, spk_emb) --> pitch_emb --> o_en = o_en + pitch_emb-----| -> CondConv(spk_emb) -> spk_proj
+        spk, text -> Encoder(text, spk)--> o_en, spk_emb -----> DurationPredictor(o_en, spk_emb)--> dur -------------------------> Expand(o_en, dur) -> PositionEncoding(o_en_expand) -> Decoder(o_en_expand_pos, spk_proj) -> mel_out
+
     `ForwardTTS` can be configured to one of these architectures,
 
     - FastPitch
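In this diagram the encoder returns both the text encoding `o_en` and the speaker embedding `spk_emb`, and the same `spk_emb` is then passed as `g` to the duration predictor, the pitch predictor, and the decoder. A minimal sketch of that conditioning pattern, not the library's actual modules: the speaker embedding is projected to the feature width and added to every time step before the prediction head.

    import torch
    from torch import nn

    class SpeakerConditionedPredictor(nn.Module):
        """Illustrative only: a per-token predictor conditioned on a speaker embedding `g`."""

        def __init__(self, in_channels: int, spk_channels: int, hidden_channels: int = 256):
            super().__init__()
            self.cond = nn.Conv1d(spk_channels, in_channels, 1)  # project spk_emb to the feature width
            self.net = nn.Sequential(
                nn.Conv1d(in_channels, hidden_channels, 3, padding=1),
                nn.ReLU(),
                nn.Conv1d(hidden_channels, 1, 1),
            )

        def forward(self, x, x_mask, g=None):
            # x: [B, C, T], x_mask: [B, 1, T], g: [B, C_spk]
            if g is not None:
                x = x + self.cond(g.unsqueeze(-1))  # broadcast one embedding over all time steps
            return self.net(x * x_mask) * x_mask  # [B, 1, T], e.g. log-durations or pitch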
@@ -610,19 +615,19 @@ class ForwardTTS(BaseTTS):
             - g: :math:`[B, C]`
             - pitch: :math:`[B, 1, T]`
         """
-        g = self._set_speaker_input(aux_input)
+        spk = self._set_speaker_input(aux_input)
         # compute sequence masks
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()  # [B, 1, T_max2]
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()  # [B, 1, T_max]
         # encoder pass
-        x_emb, x_mask, g, o_en = self._forward_encoder(
-            x, x_mask, g
+        x_emb, x_mask, spk_emb, o_en = self._forward_encoder(
+            x, x_mask, spk
         )  # [B, T_max, C_en], [B, 1, T_max], [B, C], [B, C_en, T_max]
         # duration predictor pass
         if self.args.detach_duration_predictor:
-            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=g)  # [B, 1, T_max]
+            o_dr_log = self.duration_predictor(x=o_en.detach(), x_mask=x_mask, g=spk_emb)  # [B, 1, T_max]
         else:
-            o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=g)  # [B, 1, T_max]
+            o_dr_log = self.duration_predictor(x=o_en, x_mask=x_mask, g=spk_emb)  # [B, 1, T_max]
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
         # generate attn mask from predicted durations
         dur_predictor_attn = self.generate_attn(o_dr.squeeze(1), x_mask)  # [B, T_max, T_max2']
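When `self.args.detach_duration_predictor` is true, the predictor sees a detached copy of the encoder output, so the duration loss updates only the duration predictor and never backpropagates into the encoder. A minimal sketch of that effect with toy modules (the names are placeholders, not the ForwardTTS classes):

    import torch
    from torch import nn

    encoder = nn.Linear(8, 8)             # stand-in for the text encoder
    duration_predictor = nn.Linear(8, 1)  # stand-in for the duration predictor

    x = torch.randn(2, 8)
    o_en = encoder(x)

    # detached pass: gradients from the duration loss stop at o_en
    o_dr_log = duration_predictor(o_en.detach())
    o_dr_log.sum().backward()

    print(encoder.weight.grad)                          # None: the encoder receives no gradient
    print(duration_predictor.weight.grad is not None)   # True: the predictor still learns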
@@ -644,7 +649,7 @@ class ForwardTTS(BaseTTS):
         avg_pitch = None
         if self.args.use_pitch:
             o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(
-                o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=g
+                o_en=o_en, x_mask=x_mask, pitch=pitch, dr=dr, g=spk_emb
             )
             o_en = o_en + o_pitch_emb
         # expand encoder outputs
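`_forward_pitch_predictor` receives the ground-truth frame-level `pitch` together with the durations `dr` because, as in FastPitch, the training target is the average pitch over the frames each input token spans. A simple reference loop for that averaging (not the library's vectorized implementation; the shapes are assumed as in the comments):

    import torch

    def average_pitch_per_token(pitch: torch.Tensor, dr: torch.Tensor) -> torch.Tensor:
        """pitch: [B, 1, T_frames]; dr: [B, T_tokens] integer durations summing to T_frames.
        Returns the average pitch per input token, shape [B, 1, T_tokens]."""
        batch_size = pitch.shape[0]
        n_tokens = dr.shape[1]
        avg_pitch = pitch.new_zeros(batch_size, 1, n_tokens)
        for b in range(batch_size):
            start = 0
            for t in range(n_tokens):
                end = start + int(dr[b, t])
                if end > start:  # tokens with zero duration keep an average of 0
                    avg_pitch[b, 0, t] = pitch[b, 0, start:end].mean()
                start = end
        return avg_pitch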
@@ -652,10 +657,10 @@ class ForwardTTS(BaseTTS):
             o_en=o_en, y_lengths=y_lengths, dr=dr, x_mask=x_mask
         )  # [B, 1, T_max2], [B, C_en, T_max2], [B, T_max2, T_max]
         # decoder pass
-        o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)  # [B, T_max2, C_de]
+        o_de = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb)  # [B, T_max2, C_de]
         outputs = {
             "model_outputs": o_de,  # [B, T, C]
-            "g": g,  # [B, C]
+            "spk_emb": spk_emb,  # [B, C]
             "durations_log": o_dr_log.squeeze(1),  # [B, T]
             "durations": o_dr.squeeze(1),  # [B, T]
             "attn_durations": dur_predictor_attn,  # for visualization [B, T_en, T_de']
@@ -688,11 +693,11 @@ class ForwardTTS(BaseTTS):
             - x_lengths: [B]
             - g: [B, C]
         """
-        g = self._set_speaker_input(aux_input)
+        spk = self._set_speaker_input(aux_input)
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
         # encoder pass
-        _, x_mask, g, o_en = self._forward_encoder(x, x_mask, g)
+        _, x_mask, spk_emb, o_en = self._forward_encoder(x, x_mask, spk)
         # duration predictor pass
         o_dr_log = self.duration_predictor(o_en, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
@@ -700,7 +705,7 @@ class ForwardTTS(BaseTTS):
         # pitch predictor pass
         o_pitch = None
         if self.args.use_pitch:
-            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
+            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en=o_en, x_mask=x_mask, g=spk_emb)
             o_en = o_en + o_pitch_emb
         # expand encoder outputs
         y_mask, o_en_ex, attn = self._expand_encoder(o_en=o_en, y_lengths=y_lengths, dr=o_dr, x_mask=x_mask)
@@ -708,13 +713,13 @@ class ForwardTTS(BaseTTS):
             "alignments": attn,
             "pitch": o_pitch,
             "durations": o_dr,
-            "g": g,
+            "spk_emb": spk_emb,
         }
         if skip_decoder:
             outputs["o_en_ex"] = o_en_ex
         else:
             # decoder pass
-            outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=g)
+            outputs["model_outputs"] = self._forward_decoder(o_en_ex=o_en_ex, y_mask=y_mask, g=spk_emb)
         return outputs
 
     def train_step(self, batch: dict, criterion: nn.Module):
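At inference time the speaker is chosen through `aux_input`, which `_set_speaker_input` turns into `spk` (a speaker id or an external d-vector, depending on the configuration). A hedged usage sketch; `model` stands for a loaded `ForwardTTS` instance, and the `aux_input` keys and tensor sizes are assumptions used for illustration, not values taken from this diff:

    import torch

    # model: a trained ForwardTTS instance (placeholder); token_ids: [1, T] input ids
    token_ids = torch.randint(0, 100, (1, 42))

    # pick the speaker either by id (internal embedding table) or by d-vector
    # (external speaker encoder); key names assumed, check your installed version
    aux_input = {"d_vectors": torch.randn(1, 512), "speaker_ids": None}

    with torch.no_grad():
        outputs = model.inference(token_ids, aux_input=aux_input)

    mel = outputs["model_outputs"]  # predicted spectrogram, [1, T_de, C_mel]
    spk_emb = outputs["spk_emb"]    # speaker embedding used for conditioning, [1, C]

When `skip_decoder` is set, the returned dictionary carries the expanded encoder output under `o_en_ex` instead of `model_outputs`, as the hunk above shows.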