mirror of https://github.com/coqui-ai/TTS.git
Find model files
This commit is contained in:
parent
bfee55af2b
commit
160de0222d
|
@ -3,6 +3,7 @@ import json
|
|||
import os
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from shutil import copyfile, rmtree
|
||||
|
||||
import requests
|
||||
|
@ -139,8 +140,32 @@ class ModelManager(object):
|
|||
self._download_zip_file(model_item["github_rls_url"], output_path)
|
||||
# update paths in the config.json
|
||||
self._update_paths(output_path, output_config_path)
|
||||
# find downloaded files
|
||||
output_model_path, output_config_path = self._find_files(output_path)
|
||||
return output_model_path, output_config_path, model_item
|
||||
|
||||
def _find_files(self, output_path:str) -> Tuple[str, str]:
|
||||
"""Find the model and config files in the output path
|
||||
|
||||
Args:
|
||||
output_path (str): path to the model files
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: path to the model file and config file
|
||||
"""
|
||||
model_file = None
|
||||
config_file = None
|
||||
for file_name in os.listdir(output_path):
|
||||
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
|
||||
model_file = os.path.join(output_path, file_name)
|
||||
elif file_name == "config.json":
|
||||
config_file = os.path.join(output_path, file_name)
|
||||
if model_file is None:
|
||||
raise ValueError(" [!] Model file not found in the output path")
|
||||
if config_file is None:
|
||||
raise ValueError(" [!] Config file not found in the output path")
|
||||
return model_file, config_file
|
||||
|
||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
||||
"""Update paths for certain files in config.json after download.
|
||||
|
||||
|
|
Loading…
Reference in New Issue