Print Model's license when downloading (#1512)

* Print model license while downloading

* Make style

* Add a new license link

* Make style
This commit is contained in:
Eren Gölge 2022-04-19 14:18:49 +02:00 committed by GitHub
parent 4953636b14
commit 7133f8f47d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 1 deletions

View File

@ -4,13 +4,23 @@ import os
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Tuple from typing import Dict, Tuple
import requests import requests
from TTS.config import load_config from TTS.config import load_config
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
"mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
"mit": "https://choosealicense.com/licenses/mit/",
"apache2.0": "https://choosealicense.com/licenses/apache-2.0/",
"apache2": "https://choosealicense.com/licenses/apache-2.0/",
"cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
}
class ModelManager(object): class ModelManager(object):
"""Manage TTS models defined in .models.json. """Manage TTS models defined in .models.json.
@ -108,6 +118,22 @@ class ModelManager(object):
for dataset in self.models_dict[model_type][lang]: for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}") print(f" >: {model_type}/{lang}/{dataset}")
@staticmethod
def print_model_license(model_item: Dict):
"""Print the license of a model
Args:
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
else:
print(" > Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
def download_model(self, model_name): def download_model(self, model_name):
"""Download model files given the full model name. """Download model files given the full model name.
Model name is in the format Model name is in the format
@ -135,6 +161,7 @@ class ModelManager(object):
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
# download from github release # download from github release
self._download_zip_file(model_item["github_rls_url"], output_path) self._download_zip_file(model_item["github_rls_url"], output_path)
self.print_model_license(model_item=model_item)
# find downloaded files # find downloaded files
output_model_path, output_config_path = self._find_files(output_path) output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json # update paths in the config.json