From 420901f4c23a2f9882dcb544cca9261f8376ec18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 12 Feb 2021 14:41:17 +0000 Subject: [PATCH] linter fixes --- TTS/bin/find_unique_chars.py | 5 ++--- TTS/utils/arguments.py | 6 ++---- TTS/utils/manage.py | 3 ++- hubconf.py | 11 ++++++----- tests/test_demo_server.py | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index e6c35878..654a3ff9 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -7,16 +7,15 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name def main(): + # pylint: disable=bad-continuation parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n''' '''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\ - ''' Example runs: python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv - ''', - formatter_class=RawTextHelpFormatter) + ''', formatter_class=RawTextHelpFormatter) parser.add_argument( '--dataset', diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 031a3140..7d8f4adf 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -6,15 +6,13 @@ import argparse import glob import os import re -import json from TTS.tts.utils.generic_utils import check_config_tts +from TTS.tts.utils.text.symbols import parse_symbols from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import create_experiment_folder, get_git_branch -from TTS.utils.io import (copy_model_files, load_config, - save_characters_to_config) +from TTS.utils.io import copy_model_files, load_config from TTS.utils.tensorboard_logger import TensorboardLogger -from TTS.tts.utils.text.symbols import parse_symbols def parse_arguments(argv): diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 97cdf2b6..bd236dda 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -122,7 +122,8 @@ class ModelManager(object): """Download files from GDrive using their file ids""" gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) - def _download_zip_file(self, file_url, output): + @staticmethod + def _download_zip_file(file_url, output): """Download the target zip file and extract the files to a folder with the same name as the zip file.""" r = requests.get(file_url) diff --git a/hubconf.py b/hubconf.py index 7fc020b5..13549dfe 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,11 +1,11 @@ -dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak +dependencies = ['torch', 'gdown', 'pysbd', 'phonemizer', 'unidecode'] # apt install espeak-ng import torch from TTS.utils.synthesizer import Synthesizer from TTS.utils.manage import ModelManager -def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder_models/en/ljspeech/mulitband-melgan', use_cuda=False): +def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name=None, use_cuda=False): """TTS entry point for PyTorch Hub that provides a Synthesizer object to synthesize speech from a give text. Example: @@ -15,7 +15,7 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder Args: model_name (str, optional): One of the model names from .model.json. Defaults to 'tts_models/en/ljspeech/tacotron2-DCA'. - vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/mulitband-melgan'. + vocoder_name (str, optional): One of the model names from .model.json. Defaults to 'vocoder_models/en/ljspeech/multiband-melgan'. pretrained (bool, optional): [description]. Defaults to True. Returns: @@ -23,8 +23,9 @@ def tts(model_name='tts_models/en/ljspeech/tacotron2-DCA', vocoder_name='vocoder """ manager = ModelManager() - model_path, config_path = manager.download_model(model_name) - vocoder_path, vocoder_config_path = manager.download_model(vocoder_name) + model_path, config_path, model_item = manager.download_model(model_name) + vocoder_name = model_item['default_vocoder'] if vocoder_name is None else vocoder_name + vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name) # create synthesizer synt = Synthesizer(model_path, config_path, vocoder_path, vocoder_config_path, use_cuda) diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index bccff55d..1de3f558 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -21,7 +21,7 @@ class DemoServerTest(unittest.TestCase): num_chars = len(phonemes) if config.use_phonemes else len(symbols) model = setup_model(num_chars, 0, config) output_path = os.path.join(get_tests_output_path()) - save_checkpoint(model, None, 10, 10, 1, output_path) + save_checkpoint(model, None, 10, 10, 1, output_path, None) def test_in_out(self): self._create_random_model()