From f7695951128bdcf52a235610b840390ec3d91da3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 20 Dec 2021 11:53:44 +0000 Subject: [PATCH] Add more listing options to ModelManager --- TTS/utils/manage.py | 72 ++++++++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 21 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index cfbbdff0..d1dedbe0 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -46,36 +46,66 @@ class ModelManager(object): with open(file_path, "r", encoding="utf-8") as json_file: self.models_dict = json.load(json_file) - def list_langs(self): - print(" Name format: type/language") - for model_type in self.models_dict: - for lang in self.models_dict[model_type]: - print(f" >: {model_type}/{lang} ") + def _list_models(self, model_type, model_count=0): + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") + else: + print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list - def list_datasets(self): - print(" Name format: type/language/dataset") - for model_type in self.models_dict: - for lang in self.models_dict[model_type]: - for dataset in self.models_dict[model_type][lang]: - print(f" >: {model_type}/{lang}/{dataset}") + def _list_for_model_type(self, model_type): + print(" Name format: language/dataset/model") + models_name_list = [] + model_count = 1 + model_type = "tts_models" + models_name_list.extend(self._list_models(model_type, model_count)) + return [name.replace(model_type + "/", "") for name in models_name_list] def list_models(self): print(" Name format: type/language/dataset/model") models_name_list = [] model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_langs(self): + """Print all the available languages""" + print(" Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + print(f" >: {model_type}/{lang} ") + + def list_datasets(self): + """Print all the datasets""" + print(" Name format: type/language/dataset") for model_type in self.models_dict: for lang in self.models_dict[model_type]: for dataset in self.models_dict[model_type][lang]: - for model in self.models_dict[model_type][lang][dataset]: - model_full_name = f"{model_type}--{lang}--{dataset}--{model}" - output_path = os.path.join(self.output_prefix, model_full_name) - if os.path.exists(output_path): - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model} [already downloaded]") - else: - print(f" {model_count}: {model_type}/{lang}/{dataset}/{model}") - models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") - model_count += 1 - return models_name_list + print(f" >: {model_type}/{lang}/{dataset}") def download_model(self, model_name): """Download model files given the full model name.