From b7caad39e09b664101a584d90ac1b0f31a6a59e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 20 Jul 2021 14:47:12 +0200 Subject: [PATCH] Make optional to detach duration predictor input --- TTS/tts/models/fast_pitch.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index 9b826c3f..b6c0e60f 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -44,6 +44,7 @@ class FastPitchArgs(Coqpit): ) use_d_vector: bool = False d_vector_dim: int = 0 + detach_duration_predictor: bool = False class FastPitch(BaseTTS): @@ -237,7 +238,10 @@ class FastPitch(BaseTTS): """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + if self.config.model_args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) o_en = o_en + o_pitch_emb o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) @@ -250,6 +254,7 @@ class FastPitch(BaseTTS): } return outputs + @torch.no_grad() def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument """ Shapes: @@ -267,7 +272,7 @@ class FastPitch(BaseTTS): x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) # duration predictor pass - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) # pitch predictor pass o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)