mirror of https://github.com/coqui-ai/TTS.git
Print model license while downloading
This commit is contained in:
parent
c18100d112
commit
5cc8e48c3a
|
@ -4,7 +4,7 @@ import os
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
from typing import Tuple
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -12,6 +12,16 @@ from TTS.config import load_config
|
||||||
from TTS.utils.generic_utils import get_user_data_dir
|
from TTS.utils.generic_utils import get_user_data_dir
|
||||||
|
|
||||||
|
|
||||||
|
LICENSE_URLS = {
|
||||||
|
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
|
||||||
|
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
|
||||||
|
"mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
|
||||||
|
"mit": "https://choosealicense.com/licenses/mit/",
|
||||||
|
"apache2.0": "https://choosealicense.com/licenses/apache-2.0/",
|
||||||
|
"apache2": "https://choosealicense.com/licenses/apache-2.0/",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""Manage TTS models defined in .models.json.
|
"""Manage TTS models defined in .models.json.
|
||||||
It provides an interface to list and download
|
It provides an interface to list and download
|
||||||
|
@ -108,6 +118,21 @@ class ModelManager(object):
|
||||||
for dataset in self.models_dict[model_type][lang]:
|
for dataset in self.models_dict[model_type][lang]:
|
||||||
print(f" >: {model_type}/{lang}/{dataset}")
|
print(f" >: {model_type}/{lang}/{dataset}")
|
||||||
|
|
||||||
|
def print_model_license(self, model_item: Dict):
|
||||||
|
"""Print the license of a model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_item (dict): model item in the models.json
|
||||||
|
"""
|
||||||
|
if "license" in model_item and model_item["license"].strip() != "":
|
||||||
|
print(f" > Model's license - {model_item['license']}")
|
||||||
|
if model_item['license'].lower() in LICENSE_URLS:
|
||||||
|
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
|
||||||
|
else:
|
||||||
|
print(f" > Check https://opensource.org/licenses for more info.")
|
||||||
|
else:
|
||||||
|
print(" > Model's license - No license information available")
|
||||||
|
|
||||||
def download_model(self, model_name):
|
def download_model(self, model_name):
|
||||||
"""Download model files given the full model name.
|
"""Download model files given the full model name.
|
||||||
Model name is in the format
|
Model name is in the format
|
||||||
|
@ -135,6 +160,7 @@ class ModelManager(object):
|
||||||
print(f" > Downloading model to {output_path}")
|
print(f" > Downloading model to {output_path}")
|
||||||
# download from github release
|
# download from github release
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||||
|
self.print_model_license(model_item=model_item)
|
||||||
# find downloaded files
|
# find downloaded files
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
|
|
Loading…
Reference in New Issue