mirror of https://github.com/coqui-ai/TTS.git
Make detaching the duration predictor input optional
This commit is contained in:
parent
9af42f7886
commit
b7caad39e0
|
@ -44,6 +44,7 @@ class FastPitchArgs(Coqpit):
|
||||||
)
|
)
|
||||||
use_d_vector: bool = False
|
use_d_vector: bool = False
|
||||||
d_vector_dim: int = 0
|
d_vector_dim: int = 0
|
||||||
|
detach_duration_predictor: bool = False
|
||||||
|
|
||||||
|
|
||||||
class FastPitch(BaseTTS):
|
class FastPitch(BaseTTS):
|
||||||
|
@ -237,7 +238,10 @@ class FastPitch(BaseTTS):
|
||||||
"""
|
"""
|
||||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
if self.config.model_args.detach_duration_predictor:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||||
|
else:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
|
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
|
||||||
o_en = o_en + o_pitch_emb
|
o_en = o_en + o_pitch_emb
|
||||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
|
||||||
|
@ -250,6 +254,7 @@ class FastPitch(BaseTTS):
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
|
@ -267,7 +272,7 @@ class FastPitch(BaseTTS):
|
||||||
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
|
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
|
||||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||||
# duration predictor pass
|
# duration predictor pass
|
||||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
# pitch predictor pass
|
# pitch predictor pass
|
||||||
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
|
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
|
||||||
|
|
Loading…
Reference in New Issue