Print Model's license when downloading (#1512)

* Print model license while downloading

* Make style

* Add a new license link

* Make style
This commit is contained in:
Eren Gölge 2022-04-19 14:18:49 +02:00 committed by GitHub
parent 4953636b14
commit 7133f8f47d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 28 additions and 1 deletions

View File

@ -4,13 +4,23 @@ import os
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Tuple
from typing import Dict, Tuple
import requests
from TTS.config import load_config
from TTS.utils.generic_utils import get_user_data_dir
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
"mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
"mit": "https://choosealicense.com/licenses/mit/",
"apache2.0": "https://choosealicense.com/licenses/apache-2.0/",
"apache2": "https://choosealicense.com/licenses/apache-2.0/",
"cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
}
class ModelManager(object):
"""Manage TTS models defined in .models.json.
@ -108,6 +118,22 @@ class ModelManager(object):
for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}")
@staticmethod
def print_model_license(model_item: Dict):
"""Print the license of a model
Args:
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
if model_item["license"].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
else:
print(" > Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
@ -135,6 +161,7 @@ class ModelManager(object):
print(f" > Downloading model to {output_path}")
# download from github release
self._download_zip_file(model_item["github_rls_url"], output_path)
self.print_model_license(model_item=model_item)
# find downloaded files
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json