Fix `ModelManager` model download

This commit is contained in:
Eren Gölge 2021-07-02 10:47:05 +02:00
parent 9352cb4136
commit 196876feb1
2 changed files with 13 additions and 14 deletions
TTS

View File

@ -98,8 +98,7 @@ class SpeedySpeech(BaseTTS):
self.config = config self.config = config
if "characters" in config: if "characters" in config:
chars, self.config = self.get_characters(config) chars, self.config, self.num_chars = self.get_characters(config)
self.num_chars = len(chars)
self.length_scale = ( self.length_scale = (
float(config.model_args.length_scale) float(config.model_args.length_scale)

View File

@ -3,7 +3,7 @@ import json
import os import os
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile, rmtree
import gdown import gdown
import requests import requests
@ -83,7 +83,7 @@ class ModelManager(object):
'type/language/dataset/model' 'type/language/dataset/model'
e.g. 'tts_model/en/ljspeech/tacotron' e.g. 'tts_model/en/ljspeech/tacotron'
Every model must have the following files Every model must have the following files:
- *.pth.tar : pytorch model checkpoint file. - *.pth.tar : pytorch model checkpoint file.
- config.json : model config file. - config.json : model config file.
- scale_stats.npy (if exist): scale values for preprocessing. - scale_stats.npy (if exist): scale values for preprocessing.
@ -101,11 +101,7 @@ class ModelManager(object):
output_path = os.path.join(self.output_prefix, model_full_name) output_path = os.path.join(self.output_prefix, model_full_name)
output_model_path = os.path.join(output_path, "model_file.pth.tar") output_model_path = os.path.join(output_path, "model_file.pth.tar")
output_config_path = os.path.join(output_path, "config.json") output_config_path = os.path.join(output_path, "config.json")
# NOTE : band-aid for removing phoneme support
# if "needs_phonemizer" in model_item and model_item["needs_phonemizer"]:
# raise RuntimeError(
# " [!] Use 🐸TTS <= v0.0.13 for this model. Current version does not support phoneme based models."
# )
if os.path.exists(output_path): if os.path.exists(output_path):
print(f" > {model_name} is already downloaded.") print(f" > {model_name} is already downloaded.")
else: else:
@ -116,7 +112,6 @@ class ModelManager(object):
# download files to the output path # download files to the output path
if self._check_dict_key(model_item, "github_rls_url"): if self._check_dict_key(model_item, "github_rls_url"):
# download from github release # download from github release
# TODO: pass output_path
self._download_zip_file(model_item["github_rls_url"], output_path) self._download_zip_file(model_item["github_rls_url"], output_path)
else: else:
# download from gdrive # download from gdrive
@ -146,15 +141,20 @@ class ModelManager(object):
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
@staticmethod @staticmethod
def _download_zip_file(file_url, output): def _download_zip_file(file_url, output_folder):
"""Download the github releases""" """Download the github releases"""
# download the file
r = requests.get(file_url) r = requests.get(file_url)
# extract the file
with zipfile.ZipFile(io.BytesIO(r.content)) as z: with zipfile.ZipFile(io.BytesIO(r.content)) as z:
z.extractall(output) z.extractall(output_folder)
# move the files to the outer path
for file_path in z.namelist()[1:]: for file_path in z.namelist()[1:]:
src_path = os.path.join(output, file_path) src_path = os.path.join(output_folder, file_path)
dst_path = os.path.join(output, os.path.basename(file_path)) dst_path = os.path.join(output_folder, os.path.basename(file_path))
copyfile(src_path, dst_path) copyfile(src_path, dst_path)
# remove the extracted folder
rmtree(os.path.join(output_folder, z.namelist()[0]))
@staticmethod @staticmethod
def _check_dict_key(my_dict, key): def _check_dict_key(my_dict, key):