From 93a6bdfd6cad5fcf951c032a6ab3a517a7798bb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 1 Feb 2021 13:18:56 +0000 Subject: [PATCH] linter fixes and version updates for deps --- TTS/bin/train_vocoder_wavegrad.py | 2 +- hubconf.py | 15 ++++++--------- pyproject.toml | 2 +- tests/test_vocoder_gan_datasets.py | 3 ++- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index b104652d..fe5fb3d7 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -344,7 +344,7 @@ def main(args): # pylint: disable=redefined-outer-name # setup criterion criterion = torch.nn.L1Loss().cuda() - + if use_cuda: model.cuda() criterion.cuda() diff --git a/hubconf.py b/hubconf.py index 0e2e60d8..9de4f7b2 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,9 +1,6 @@ dependencies = ['torch', 'gdown'] import torch -import os -import zipfile -from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.synthesizer import Synthesizer from TTS.utils.manage import ModelManager @@ -15,7 +12,7 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder >>> synthesizer = torch.hub.load('mozilla/TTS', 'tts', source='github') >>> wavs = synthesizer.tts("This is a test! This is also a test!!") wavs - is a list of values of the synthesized speech. - + Args: model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'. vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'. @@ -23,15 +20,15 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder Returns: TTS.utils.synthesizer.Synthesizer: Synthesizer object wrapping both vocoder and tts models. - """ + """ manager = ModelManager() - + model_path, config_path = manager.download_model(model_name) vocoder_path, vocoder_config_path = manager.download_model(vocoder_name) - + # create synthesizer - synthesizer = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path) - return synthesizer + synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path) + return synt if __name__ == '__main__': diff --git a/pyproject.toml b/pyproject.toml index 8b8da28d..b6c632d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,2 @@ [build-system] -requires = ["setuptools", "wheel", "Cython", "numpy==1.17.5"] \ No newline at end of file +requires = ["setuptools", "wheel", "Cython", "numpy==1.17.5"] diff --git a/tests/test_vocoder_gan_datasets.py b/tests/test_vocoder_gan_datasets.py index 2a487d9a..99a25dcf 100644 --- a/tests/test_vocoder_gan_datasets.py +++ b/tests/test_vocoder_gan_datasets.py @@ -61,7 +61,8 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us mel = ap.melspectrogram(audio) # the first 2 and the last 2 frames are skipped due to the padding # differences in stft - assert (feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum() <= 0, f' [!] {(feat - mel[:, :feat1.shape[-1]])[:, 2:-2].sum()}' + max_diff = abs((feat - mel[:, :feat1.shape[-1]])[:, 2:-2]).max() + assert max_diff <= 0, f' [!] {max_diff}' count_iter += 1 # if count_iter == max_iter: