From aec0b78aff2f5de2d178514887fc2edad76c86c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Mar 2021 17:07:15 +0100 Subject: [PATCH] duration predictor fix 2 --- TTS/tts/models/align_tts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 558a0e43..99242ad1 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -196,7 +196,6 @@ class AlignTTS(nn.Module): o_en_ex = self._sum_speaker_embedding(o_en_ex, g) # decoder pass o_de = self.decoder(o_en_ex, y_mask, g=g) - return o_de, attn.transpose(1, 2) # def _forward_mas(self, o_en, y, y_lengths, x_mask): @@ -225,8 +224,8 @@ class AlignTTS(nn.Module): g: [B, C] """ o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) dr_mas, mu, log_sigma, logp_max_path = self._forward_mdn(o_en, y, y_lengths, x_mask) - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) # TODO: compute attn once o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) dr_mas_log = torch.log(1 + dr_mas).squeeze(1) @@ -242,8 +241,8 @@ class AlignTTS(nn.Module): # pad input to prevent dropping the last word x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) # duration predictor pass - o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) y_lengths = o_dr.sum(1) o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)