From ab37fa9c3905f9e0bd3c0f73c423e03ab8265ff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 10 Sep 2021 17:25:00 +0000 Subject: [PATCH] Edit AlignTTS --- TTS/tts/models/align_tts.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 0d75f482..78fbaeab 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -10,9 +10,8 @@ from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding -from TTS.tts.utils.helpers import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec @@ -168,7 +167,12 @@ class AlignTTS(BaseTTS): return dr_mas.squeeze(1), log_p @staticmethod - def convert_dr_to_align(dr, x_mask, y_mask): + def generate_attn(dr, x_mask, y_mask=None): + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) return attn @@ -187,7 +191,7 @@ class AlignTTS(BaseTTS): [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]] """ - attn = self.convert_dr_to_align(dr, x_mask, y_mask) + attn = self.generate_attn(dr, x_mask, y_mask) o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn @@ -275,7 +279,7 @@ class AlignTTS(BaseTTS): o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) - attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask) + attn = self.generate_attn(dr_mas, x_mask, y_mask) elif phase == 1: # train decoder o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)