Edit AlignTTS

This commit is contained in:
Eren Gölge 2021-09-10 17:25:00 +00:00
parent 66732025e1
commit ab37fa9c39
1 changed file with 9 additions and 5 deletions

View File

@@ -10,9 +10,8 @@ from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
-from TTS.tts.utils.helpers import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
-from TTS.tts.utils.helpers import sequence_mask
+from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_fsspec
@@ -168,7 +167,12 @@ class AlignTTS(BaseTTS):
return dr_mas.squeeze(1), log_p
@staticmethod
-def convert_dr_to_align(dr, x_mask, y_mask):
+def generate_attn(dr, x_mask, y_mask=None):
+# compute decode mask from the durations
+if y_mask is None:
+y_lengths = dr.sum(1).long()
+y_lengths[y_lengths < 1] = 1
+y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn
@@ -187,7 +191,7 @@ class AlignTTS(BaseTTS):
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
"""
-attn = self.convert_dr_to_align(dr, x_mask, y_mask)
+attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
@@ -275,7 +279,7 @@ class AlignTTS(BaseTTS):
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
-attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
+attn = self.generate_attn(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)