mirror of https://github.com/coqui-ai/TTS.git
Fix `ModelManager` model download
This commit is contained in:
parent
9352cb4136
commit
196876feb1
TTS
|
@ -98,8 +98,7 @@ class SpeedySpeech(BaseTTS):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
if "characters" in config:
|
if "characters" in config:
|
||||||
chars, self.config = self.get_characters(config)
|
chars, self.config, self.num_chars = self.get_characters(config)
|
||||||
self.num_chars = len(chars)
|
|
||||||
|
|
||||||
self.length_scale = (
|
self.length_scale = (
|
||||||
float(config.model_args.length_scale)
|
float(config.model_args.length_scale)
|
||||||
|
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile
|
from shutil import copyfile, rmtree
|
||||||
|
|
||||||
import gdown
|
import gdown
|
||||||
import requests
|
import requests
|
||||||
|
@ -83,7 +83,7 @@ class ModelManager(object):
|
||||||
'type/language/dataset/model'
|
'type/language/dataset/model'
|
||||||
e.g. 'tts_model/en/ljspeech/tacotron'
|
e.g. 'tts_model/en/ljspeech/tacotron'
|
||||||
|
|
||||||
Every model must have the following files
|
Every model must have the following files:
|
||||||
- *.pth.tar : pytorch model checkpoint file.
|
- *.pth.tar : pytorch model checkpoint file.
|
||||||
- config.json : model config file.
|
- config.json : model config file.
|
||||||
- scale_stats.npy (if exist): scale values for preprocessing.
|
- scale_stats.npy (if exist): scale values for preprocessing.
|
||||||
|
@ -101,11 +101,7 @@ class ModelManager(object):
|
||||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||||
output_model_path = os.path.join(output_path, "model_file.pth.tar")
|
output_model_path = os.path.join(output_path, "model_file.pth.tar")
|
||||||
output_config_path = os.path.join(output_path, "config.json")
|
output_config_path = os.path.join(output_path, "config.json")
|
||||||
# NOTE : band-aid for removing phoneme support
|
|
||||||
# if "needs_phonemizer" in model_item and model_item["needs_phonemizer"]:
|
|
||||||
# raise RuntimeError(
|
|
||||||
# " [!] Use 🐸TTS <= v0.0.13 for this model. Current version does not support phoneme based models."
|
|
||||||
# )
|
|
||||||
if os.path.exists(output_path):
|
if os.path.exists(output_path):
|
||||||
print(f" > {model_name} is already downloaded.")
|
print(f" > {model_name} is already downloaded.")
|
||||||
else:
|
else:
|
||||||
|
@ -116,7 +112,6 @@ class ModelManager(object):
|
||||||
# download files to the output path
|
# download files to the output path
|
||||||
if self._check_dict_key(model_item, "github_rls_url"):
|
if self._check_dict_key(model_item, "github_rls_url"):
|
||||||
# download from github release
|
# download from github release
|
||||||
# TODO: pass output_path
|
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||||
else:
|
else:
|
||||||
# download from gdrive
|
# download from gdrive
|
||||||
|
@ -146,15 +141,20 @@ class ModelManager(object):
|
||||||
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
|
gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _download_zip_file(file_url, output):
|
def _download_zip_file(file_url, output_folder):
|
||||||
"""Download the github releases"""
|
"""Download the github releases"""
|
||||||
|
# download the file
|
||||||
r = requests.get(file_url)
|
r = requests.get(file_url)
|
||||||
|
# extract the file
|
||||||
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
|
with zipfile.ZipFile(io.BytesIO(r.content)) as z:
|
||||||
z.extractall(output)
|
z.extractall(output_folder)
|
||||||
|
# move the files to the outer path
|
||||||
for file_path in z.namelist()[1:]:
|
for file_path in z.namelist()[1:]:
|
||||||
src_path = os.path.join(output, file_path)
|
src_path = os.path.join(output_folder, file_path)
|
||||||
dst_path = os.path.join(output, os.path.basename(file_path))
|
dst_path = os.path.join(output_folder, os.path.basename(file_path))
|
||||||
copyfile(src_path, dst_path)
|
copyfile(src_path, dst_path)
|
||||||
|
# remove the extracted folder
|
||||||
|
rmtree(os.path.join(output_folder, z.namelist()[0]))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_dict_key(my_dict, key):
|
def _check_dict_key(my_dict, key):
|
||||||
|
|
Loading…
Reference in New Issue