From 7133f8f47d6c0ed0ce4c3beefeb8112ce94d7f6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 19 Apr 2022 14:18:49 +0200 Subject: [PATCH] Print Model's license when downloading (#1512) * Print model license while downloading * Make style * Add a new license link * Make style --- TTS/utils/manage.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 674d5a47..0ef3675b 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -4,13 +4,23 @@ import os import zipfile from pathlib import Path from shutil import copyfile, rmtree -from typing import Tuple +from typing import Dict, Tuple import requests from TTS.config import load_config from TTS.utils.generic_utils import get_user_data_dir +LICENSE_URLS = { + "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", + "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/", + "mit": "https://choosealicense.com/licenses/mit/", + "apache2.0": "https://choosealicense.com/licenses/apache-2.0/", + "apache2": "https://choosealicense.com/licenses/apache-2.0/", + "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/", +} + class ModelManager(object): """Manage TTS models defined in .models.json. @@ -108,6 +118,22 @@ class ModelManager(object): for dataset in self.models_dict[model_type][lang]: print(f" >: {model_type}/{lang}/{dataset}") + @staticmethod + def print_model_license(model_item: Dict): + """Print the license of a model + + Args: + model_item (dict): model item in the models.json + """ + if "license" in model_item and model_item["license"].strip() != "": + print(f" > Model's license - {model_item['license']}") + if model_item["license"].lower() in LICENSE_URLS: + print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.") + else: + print(" > Check https://opensource.org/licenses for more info.") + else: + print(" > Model's license - No license information available") + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -135,6 +161,7 @@ class ModelManager(object): print(f" > Downloading model to {output_path}") # download from github release self._download_zip_file(model_item["github_rls_url"], output_path) + self.print_model_license(model_item=model_item) # find downloaded files output_model_path, output_config_path = self._find_files(output_path) # update paths in the config.json