Print model license while downloading

This commit is contained in:
Eren Gölge 2022-04-19 11:26:02 +02:00
parent c18100d112
commit 5cc8e48c3a
1 changed files with 27 additions and 1 deletions

View File

@ -4,7 +4,7 @@ import os
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from shutil import copyfile, rmtree from shutil import copyfile, rmtree
from typing import Tuple from typing import Dict, Tuple
import requests import requests
@ -12,6 +12,16 @@ from TTS.config import load_config
from TTS.utils.generic_utils import get_user_data_dir from TTS.utils.generic_utils import get_user_data_dir
LICENSE_URLS = {
"cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
"mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
"mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
"mit": "https://choosealicense.com/licenses/mit/",
"apache2.0": "https://choosealicense.com/licenses/apache-2.0/",
"apache2": "https://choosealicense.com/licenses/apache-2.0/",
}
class ModelManager(object): class ModelManager(object):
"""Manage TTS models defined in .models.json. """Manage TTS models defined in .models.json.
It provides an interface to list and download It provides an interface to list and download
@ -108,6 +118,21 @@ class ModelManager(object):
for dataset in self.models_dict[model_type][lang]: for dataset in self.models_dict[model_type][lang]:
print(f" >: {model_type}/{lang}/{dataset}") print(f" >: {model_type}/{lang}/{dataset}")
def print_model_license(self, model_item: Dict):
"""Print the license of a model
Args:
model_item (dict): model item in the models.json
"""
if "license" in model_item and model_item["license"].strip() != "":
print(f" > Model's license - {model_item['license']}")
if model_item['license'].lower() in LICENSE_URLS:
print(f" > Check {LICENSE_URLS[model_item['license'].lower()]} for more info.")
else:
print(f" > Check https://opensource.org/licenses for more info.")
else:
print(" > Model's license - No license information available")
def download_model(self, model_name): def download_model(self, model_name):
"""Download model files given the full model name. """Download model files given the full model name.
Model name is in the format Model name is in the format
@ -135,6 +160,7 @@ class ModelManager(object):
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
# download from github release # download from github release
self._download_zip_file(model_item["github_rls_url"], output_path) self._download_zip_file(model_item["github_rls_url"], output_path)
self.print_model_license(model_item=model_item)
# find downloaded files # find downloaded files
output_model_path, output_config_path = self._find_files(output_path) output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json # update paths in the config.json