From dc3596dad4322f8e8e797ea9299807ad0914d3f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 15 Feb 2021 11:29:22 +0000 Subject: [PATCH] model_manager tests --- TTS/utils/manage.py | 10 ++++++++-- tests/test_model_manager.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 tests/test_model_manager.py diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 97cdf2b6..02e515f3 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -21,9 +21,12 @@ class ModelManager(object): Args: models_file (str): path to .model.json """ - def __init__(self, models_file=None): + def __init__(self, models_file=None, output_prefix=None): super().__init__() - self.output_prefix = get_user_data_dir('tts') + if output_prefix is None: + self.output_prefix = get_user_data_dir('tts') + else: + self.output_prefix = os.path.join(output_prefix, 'tts') self.url_prefix = "https://drive.google.com/uc?id=" self.models_dict = None if models_file is not None: @@ -57,6 +60,7 @@ class ModelManager(object): def list_models(self): print(" Name format: type/language/dataset/model") + models_name_list = [] for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: @@ -67,6 +71,8 @@ class ModelManager(object): print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]") else: print(f" >: {model_type}/{lang}/{dataset}/{model}") + models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}') + return models_name_list def download_model(self, model_name): """Download model files given the full model name. diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py new file mode 100644 index 00000000..ae0a62b8 --- /dev/null +++ b/tests/test_model_manager.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3` +import os +import shutil +import glob +import unittest +from tests import get_tests_output_path +from TTS.utils.manage import ModelManager + + +def test_if_all_models_available(): + """Check if all the models are downloadable.""" + print(" > Checking the availability of all the models under the ModelManager.") + manager = ModelManager(output_prefix=get_tests_output_path()) + model_names = manager.list_models() + for model_name in model_names: + manager.download_model(model_name) + print(f" | > OK: {model_name}") + + folders = glob.glob(os.path.join(manager.output_prefix, '*')) + assert len(folders) == len(model_names) + shutil.rmtree(manager.output_prefix) \ No newline at end of file