From 07269e639b3f0d5fe25aa7144121f969eae20154 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 16 Mar 2021 17:06:46 +0100 Subject: [PATCH] fix duration predictor in AlignTTS --- TTS/tts/models/align_tts.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index c2ba8bf2..558a0e43 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -2,8 +2,9 @@ import torch import math from torch import nn from TTS.tts.layers.feed_forward.decoder import Decoder -from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor -from TTS.tts.layers.feed_forward.encoder import Encoder, PositionalEncoding +from TTS.tts.layers.align_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path from TTS.tts.layers.align_tts.mdn import MDNBlock @@ -11,6 +12,7 @@ from TTS.tts.layers.align_tts.mdn import MDNBlock + class AlignTTS(nn.Module): """Speedy Speech model with Monotonic Alignment Search https://arxiv.org/abs/2008.03802 @@ -75,11 +77,9 @@ class AlignTTS(nn.Module): self.pos_encoder = PositionalEncoding(hidden_channels) self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) - self.duration_predictor = DurationPredictor(hidden_channels + - c_in_channels) + self.duration_predictor = DurationPredictor(num_chars, hidden_channels, hidden_channels_ffn=1024, num_heads=2) self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) - # self.wn_spec_encoder = WNSpecEncoder(out_channels, hidden_channels, c_in_channels=c_in_channels) self.mdn_block = MDNBlock(hidden_channels, 2*out_channels) if num_speakers > 1 and not external_c: