diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3566cf2f..dd397687 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -3,6 +3,7 @@ import json import os import zipfile from pathlib import Path +from typing import Tuple from shutil import copyfile, rmtree import requests @@ -139,8 +140,32 @@ class ModelManager(object): self._download_zip_file(model_item["github_rls_url"], output_path) # update paths in the config.json self._update_paths(output_path, output_config_path) + # find downloaded files + output_model_path, output_config_path = self._find_files(output_path) return output_model_path, output_config_path, model_item + def _find_files(self, output_path:str) -> Tuple[str, str]: + """Find the model and config files in the output path + + Args: + output_path (str): path to the model files + + Returns: + Tuple[str, str]: path to the model file and config file + """ + model_file = None + config_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + model_file = os.path.join(output_path, file_name) + elif file_name == "config.json": + config_file = os.path.join(output_path, file_name) + if model_file is None: + raise ValueError(" [!] Model file not found in the output path") + if config_file is None: + raise ValueError(" [!] Config file not found in the output path") + return model_file, config_file + def _update_paths(self, output_path: str, config_path: str) -> None: """Update paths for certain files in config.json after download.