mirror of https://github.com/coqui-ai/TTS.git
model_manager tests
This commit is contained in:
parent
77e630348e
commit
dc3596dad4
|
@ -21,9 +21,12 @@ class ModelManager(object):
|
|||
Args:
|
||||
models_file (str): path to .model.json
|
||||
"""
|
||||
def __init__(self, models_file=None):
|
||||
def __init__(self, models_file=None, output_prefix=None):
|
||||
super().__init__()
|
||||
self.output_prefix = get_user_data_dir('tts')
|
||||
if output_prefix is None:
|
||||
self.output_prefix = get_user_data_dir('tts')
|
||||
else:
|
||||
self.output_prefix = os.path.join(output_prefix, 'tts')
|
||||
self.url_prefix = "https://drive.google.com/uc?id="
|
||||
self.models_dict = None
|
||||
if models_file is not None:
|
||||
|
@ -57,6 +60,7 @@ class ModelManager(object):
|
|||
|
||||
def list_models(self):
|
||||
print(" Name format: type/language/dataset/model")
|
||||
models_name_list = []
|
||||
for model_type in self.models_dict:
|
||||
for lang in self.models_dict[model_type]:
|
||||
for dataset in self.models_dict[model_type][lang]:
|
||||
|
@ -67,6 +71,8 @@ class ModelManager(object):
|
|||
print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
|
||||
else:
|
||||
print(f" >: {model_type}/{lang}/{dataset}/{model}")
|
||||
models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}')
|
||||
return models_name_list
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env python3`
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
import unittest
|
||||
from tests import get_tests_output_path
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def test_if_all_models_available():
|
||||
"""Check if all the models are downloadable."""
|
||||
print(" > Checking the availability of all the models under the ModelManager.")
|
||||
manager = ModelManager(output_prefix=get_tests_output_path())
|
||||
model_names = manager.list_models()
|
||||
for model_name in model_names:
|
||||
manager.download_model(model_name)
|
||||
print(f" | > OK: {model_name}")
|
||||
|
||||
folders = glob.glob(os.path.join(manager.output_prefix, '*'))
|
||||
assert len(folders) == len(model_names)
|
||||
shutil.rmtree(manager.output_prefix)
|
Loading…
Reference in New Issue