model_manager tests

This commit is contained in:
Eren Gölge 2021-02-15 11:29:22 +00:00
parent 77e630348e
commit dc3596dad4
2 changed files with 29 additions and 2 deletions

View File

@ -21,9 +21,12 @@ class ModelManager(object):
Args:
models_file (str): path to .model.json
"""
def __init__(self, models_file=None):
def __init__(self, models_file=None, output_prefix=None):
super().__init__()
self.output_prefix = get_user_data_dir('tts')
if output_prefix is None:
self.output_prefix = get_user_data_dir('tts')
else:
self.output_prefix = os.path.join(output_prefix, 'tts')
self.url_prefix = "https://drive.google.com/uc?id="
self.models_dict = None
if models_file is not None:
@ -57,6 +60,7 @@ class ModelManager(object):
def list_models(self):
print(" Name format: type/language/dataset/model")
models_name_list = []
for model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
@ -67,6 +71,8 @@ class ModelManager(object):
print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" >: {model_type}/{lang}/{dataset}/{model}")
models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}')
return models_name_list
def download_model(self, model_name):
"""Download model files given the full model name.

View File

@ -0,0 +1,21 @@
#!/usr/bin/env python3`
import os
import shutil
import glob
import unittest
from tests import get_tests_output_path
from TTS.utils.manage import ModelManager
def test_if_all_models_available():
"""Check if all the models are downloadable."""
print(" > Checking the availability of all the models under the ModelManager.")
manager = ModelManager(output_prefix=get_tests_output_path())
model_names = manager.list_models()
for model_name in model_names:
manager.download_model(model_name)
print(f" | > OK: {model_name}")
folders = glob.glob(os.path.join(manager.output_prefix, '*'))
assert len(folders) == len(model_names)
shutil.rmtree(manager.output_prefix)