From 53f54898bcbad0a1bc9aa6bfbf1604d0a18093f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 8 Apr 2021 14:22:47 +0200 Subject: [PATCH] small fixes --- TTS/vocoder/utils/generic_utils.py | 2 +- run_tests.sh | 2 +- tests/inputs/test_align_tts.json | 1 + tests/test_vocoder_losses.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 7f4c187f..77386d30 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -96,7 +96,7 @@ def setup_generator(c): in_channels=c.audio['num_mels'], out_channels=1, **c.generator_model_params) - if c.generator_model.lower() in 'melgan_generator': + elif c.generator_model.lower() in 'melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, diff --git a/run_tests.sh b/run_tests.sh index 18812318..b4878550 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -4,7 +4,7 @@ TF_CPP_MIN_LOG_LEVEL=3 # # tests nosetests tests -x &&\ -# runtime tests +# # runtime tests ./tests/test_demo_server.sh && \ ./tests/test_resample.sh && \ ./tests/test_tacotron_train.sh && \ diff --git a/tests/inputs/test_align_tts.json b/tests/inputs/test_align_tts.json index 9037b535..8815906b 100644 --- a/tests/inputs/test_align_tts.json +++ b/tests/inputs/test_align_tts.json @@ -65,6 +65,7 @@ // MODEL PARAMETERS "positional_encoding": true, "hidden_channels": 256, + "hidden_channels_dp": 128, "encoder_type": "fftransformer", "encoder_params":{ "hidden_channels_ffn": 1024 , diff --git a/tests/test_vocoder_losses.py b/tests/test_vocoder_losses.py index d578a130..765f67b3 100644 --- a/tests/test_vocoder_losses.py +++ b/tests/test_vocoder_losses.py @@ -53,6 +53,7 @@ def test_multiscale_stft_loss(): assert loss_sc < 1.0 assert loss_m + loss_sc > 0 + def test_melgan_feature_loss(): feats_real = [] feats_fake = [] @@ -71,7 +72,6 @@ def test_melgan_feature_loss(): loss = loss_func(feats_fake, feats_real) assert loss.item() <= 1.0 - feats_real = [] feats_fake = []