From ecb6b0d6ad7a8d8fc28c7d55e2e29d9629528f1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 3 Mar 2021 15:33:19 +0100 Subject: [PATCH] rename GlowTtts as GlowTTS --- TTS/tts/models/glow_tts.py | 2 +- tests/test_glow_tts.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2f9b6f9b..2e01f87c 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -9,7 +9,7 @@ from TTS.tts.utils.generic_utils import sequence_mask from TTS.tts.layers.glow_tts.monotonic_align import maximum_path, generate_path -class GlowTts(nn.Module): +class GlowTTS(nn.Module): """Glow TTS models from https://arxiv.org/abs/2005.11129 Args: diff --git a/tests/test_glow_tts.py b/tests/test_glow_tts.py index 670b7b67..38de84a9 100644 --- a/tests/test_glow_tts.py +++ b/tests/test_glow_tts.py @@ -7,7 +7,7 @@ from tests import get_tests_input_path from torch import optim from TTS.tts.layers.losses import GlowTTSLoss -from TTS.tts.models.glow_tts import GlowTts +from TTS.tts.models.glow_tts import GlowTTS from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor @@ -35,14 +35,13 @@ class GlowTTSTrainTest(unittest.TestCase): input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, c.audio['num_mels'], 30).to(device) - linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device) speaker_ids = torch.randint(0, 5, (8, )).long().to(device) - criterion = criterion = GlowTTSLoss() + criterion = GlowTTSLoss() # model to train - model = GlowTts( + model = GlowTTS( num_chars=32, hidden_channels_enc=48, hidden_channels_dec=48, @@ -71,7 +70,7 @@ class GlowTTSTrainTest(unittest.TestCase): mean_only=False).to(device) # reference model to compare model weights - model_ref = GlowTts( + model_ref = GlowTTS( num_chars=32, hidden_channels_enc=48, hidden_channels_dec=48,