Add support for model_info in CLI (#1623)

* model_info

* model_info

* model_info_by_idx and name

* model_info_by_idx and name

* model_info

* Update manage.py

* fixed linter

* fixed linter

* fixed linter

* fixed linter

* fixed return style checks

* fixed linter

* fixed linter

* fixed idx always positive

* added comments

* fix parser.args check

* fix parser.args check

* Make style

Co-authored-by: Eren G??lge <egolge@coqui.ai>
This commit is contained in:
p0p4k 2022-06-21 06:28:17 +09:00 committed by GitHub
parent 8b75e8be9c
commit 71281ff1e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 11 deletions

View File

@ -39,6 +39,18 @@ If you don't specify any models, then it uses LJSpeech based English model.
$ tts --list_models $ tts --list_models
``` ```
- Query info for model info by idx:
```
$ tts --model_info_by_idx "<model_type>/<model_query_idx>"
```
- Query info for model info by full name:
```
$ tts --model_info_by_name "<model_type>/<language>/<dataset>/<model_name>"
```
- Run TTS with default models: - Run TTS with default models:
``` ```
@ -48,7 +60,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
- Run a TTS model with its default vocoder model: - Run a TTS model with its default vocoder model:
``` ```
$ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name> $ tts --text "Text for TTS" --model_name "<language>/<dataset>/<model_name>"
``` ```
- Run with specific TTS and vocoder models from the list: - Run with specific TTS and vocoder models from the list:
@ -104,6 +116,21 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=False, default=False,
help="list available pre-trained TTS and vocoder models.", help="list available pre-trained TTS and vocoder models.",
) )
parser.add_argument(
"--model_info_by_idx",
type=str,
default=None,
help="model info using query format: <model_type>/<model_query_idx>",
)
parser.add_argument(
"--model_info_by_name",
type=str,
default=None,
help="model info using query format: <model_type>/<language>/<dataset>/<model_name>",
)
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
# Args for running pre-trained TTS models. # Args for running pre-trained TTS models.
@ -214,13 +241,16 @@ If you don't specify any models, then it uses LJSpeech based English model.
args = parser.parse_args() args = parser.parse_args()
# print the description if either text or list_models is not set # print the description if either text or list_models is not set
if ( check_args = [
not args.text args.text,
and not args.list_models args.list_models,
and not args.list_speaker_idxs args.list_speaker_idxs,
and not args.list_language_idxs args.list_language_idxs,
and not args.reference_wav args.reference_wav,
): args.model_info_by_idx,
args.model_info_by_name,
]
if not any(check_args):
parser.parse_args(["-h"]) parser.parse_args(["-h"])
# load model manager # load model manager
@ -236,12 +266,23 @@ If you don't specify any models, then it uses LJSpeech based English model.
encoder_path = None encoder_path = None
encoder_config_path = None encoder_config_path = None
# CASE1: list pre-trained TTS models # CASE1 #list : list pre-trained TTS models
if args.list_models: if args.list_models:
manager.list_models() manager.list_models()
sys.exit() sys.exit()
# CASE2: load pre-trained model paths # CASE2 #info : model info of pre-trained TTS models
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
sys.exit()
if args.model_info_by_name:
model_query_full_name = args.model_info_by_name
manager.model_info_by_full_name(model_query_full_name)
sys.exit()
# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path: if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name) model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
@ -249,7 +290,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
if args.vocoder_name is not None and not args.vocoder_path: if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE3: set custom model paths # CASE4: set custom model paths
if args.model_path is not None: if args.model_path is not None:
model_path = args.model_path model_path = args.model_path
config_path = args.config_path config_path = args.config_path

View File

@ -90,6 +90,81 @@ class ModelManager(object):
models_name_list.extend(model_list) models_name_list.extend(model_list)
return models_name_list return models_name_list
def model_info_by_idx(self, model_query):
"""Print the description of the model from .models.json file using model_idx
Args:
model_query (str): <model_tye>/<model_idx>
"""
model_name_list = []
model_type, model_query_idx = model_query.split("/")
try:
model_query_idx = int(model_query_idx)
if model_query_idx <= 0:
print("> model_query_idx should be a positive integer!")
return
except:
print("> model_query_idx should be an integer!")
return
model_count = 0
if model_type in self.models_dict:
for lang in self.models_dict[model_type]:
for dataset in self.models_dict[model_type][lang]:
for model in self.models_dict[model_type][lang][dataset]:
model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
model_count += 1
else:
print(f"> model_type {model_type} does not exist in the list.")
return
if model_query_idx > model_count:
print(f"model query idx exceeds the number of available models [{model_count}] ")
else:
model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/")
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}")
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}")
def model_info_by_full_name(self, model_query_name):
"""Print the description of the model from .models.json file using model_full_name
Args:
model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
"""
model_type, lang, dataset, model = model_query_name.split("/")
if model_type in self.models_dict:
if lang in self.models_dict[model_type]:
if dataset in self.models_dict[model_type][lang]:
if model in self.models_dict[model_type][lang][dataset]:
print(f"> model type : {model_type}")
print(f"> language supported : {lang}")
print(f"> dataset used : {dataset}")
print(f"> model name : {model}")
if "description" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> description : {self.models_dict[model_type][lang][dataset][model]['description']}"
)
else:
print("> description : coming soon")
if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]:
print(
f"> default_vocoder : {self.models_dict[model_type][lang][dataset][model]['default_vocoder']}"
)
else:
print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
else:
print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
else:
print(f"> lang {lang} does not exist for {model_type}.")
else:
print(f"> model_type {model_type} does not exist in the list.")
def list_tts_models(self): def list_tts_models(self):
"""Print all `TTS` models and return a list of model names """Print all `TTS` models and return a list of model names