From a89eb12acab8d0ce7046cdcba38d8ab51fd9eed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 10 Sep 2021 08:29:51 +0000 Subject: [PATCH] Fix glow_tts imports --- TTS/tts/models/glow_tts.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2f52b363..e643c69f 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -7,9 +7,8 @@ from torch.nn import functional as F from TTS.tts.configs import GlowTTSConfig from TTS.tts.layers.glow_tts.decoder import Decoder from TTS.tts.layers.glow_tts.encoder import Encoder -from TTS.tts.utils.helpers import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -133,7 +132,7 @@ class GlowTTS(BaseTTS): return y_mean, y_log_scale, o_attn_dur def forward( - self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value """ Shapes: @@ -185,7 +184,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def inference_with_MAS( - self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} + self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value """ It's similar to the teacher forcing in Tacotron. @@ -246,7 +245,7 @@ class GlowTTS(BaseTTS): @torch.no_grad() def decoder_inference( - self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None} + self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value """ Shapes: @@ -278,7 +277,9 @@ class GlowTTS(BaseTTS): return outputs @torch.no_grad() - def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value x_lengths = aux_input["x_lengths"] g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None @@ -331,7 +332,13 @@ class GlowTTS(BaseTTS): d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] - outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids}) + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) loss_dict = criterion( outputs["model_outputs"],