From 4dbe7ed0de30146e6a10ab28ecda0b0fe48f6a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 1 Oct 2021 09:20:07 +0000 Subject: [PATCH] Fix all-zero duration case for GlowTTS --- TTS/tts/models/glow_tts.py | 2 +- tests/tts_tests/test_glow_tts_train.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index e5c62b0e..e3a5ff3c 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -310,7 +310,7 @@ class GlowTTS(BaseTTS): o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # compute output durations w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale - w_ceil = torch.ceil(w) + w_ceil = torch.clamp_min(torch.ceil(w), 1) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_max_length = None # compute masks diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index 24c5c4cf..7da4fd33 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -10,7 +10,7 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs") config = GlowTTSConfig( - batch_size=8, + batch_size=2, eval_batch_size=8, num_loader_workers=0, num_eval_loader_workers=0, @@ -27,6 +27,7 @@ config = GlowTTSConfig( test_sentences=[ "Be a voice, not an echo.", ], + data_dep_init_steps=1.0, ) config.audio.do_trim_silence = True config.audio.trim_db = 60