diff --git a/TTS/api.py b/TTS/api.py index 540d5494..904653e8 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -130,7 +130,7 @@ class CS_API: for speaker in self.speakers: if speaker.name == name: return speaker - raise ValueError(f"Speaker {name} not found.") + raise ValueError(f"Speaker {name} not found in {self.speakers}") def id_to_speaker(self, speaker_id): for speaker in self.speakers: @@ -346,7 +346,7 @@ class TTS: def download_model_by_name(self, model_name: str): model_path, config_path, model_item = self.manager.download_model(model_name) - if "fairseq" in model_name or isinstance(model_item["github_rls_url"], list): + if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)): # return model directory if there are multiple files # we assume that the model knows how to load itself return None, None, None, None, model_path diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index a6a3705d..98e48a2a 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -252,6 +252,24 @@ class ModelManager(object): model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") self._download_tar_file(model_download_uri, output_path, self.progress_bar) + def _set_model_item(self, model_name): + # fetch model info from the dict + model_type, lang, dataset, model = model_name.split("/") + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + if "fairseq" in model_name: + model_item = { + "model_type": "tts_models", + "license": "CC BY-NC 4.0", + "default_vocoder": None, + "author": "fairseq", + "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", + } + else: + # get model from models.json + model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type + return model_item, model_full_name, model + def download_model(self, model_name): """Download model files given the full model name. Model name is in the format @@ -266,10 +284,7 @@ class ModelManager(object): Args: model_name (str): model name as explained above. """ - model_item = None - # fetch model info from the dict - model_type, lang, dataset, model = model_name.split("/") - model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + model_item, model_full_name, model = self._set_model_item(model_name) # set the model specific output path output_path = os.path.join(self.output_prefix, model_full_name) if os.path.exists(output_path): @@ -280,17 +295,7 @@ class ModelManager(object): # download from fairseq if "fairseq" in model_name: self.download_fairseq_model(model_name, output_path) - model_item = { - "model_type": "tts_models", - "license": "CC BY-NC 4.0", - "default_vocoder": None, - "author": "fairseq", - "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", - } else: - # get model from models.json - model_item = self.models_dict[model_type][lang][dataset][model] - model_item["model_type"] = model_type # download from github release if isinstance(model_item["github_rls_url"], list): self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) diff --git a/tests/inference_tests/test_python_api.py b/tests/inference_tests/test_python_api.py index 07e67967..2025fcd9 100644 --- a/tests/inference_tests/test_python_api.py +++ b/tests/inference_tests/test_python_api.py @@ -60,7 +60,7 @@ if is_coqui_available: self.assertIsNone(tts.languages) def test_studio_model(self): - tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio") + tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio") tts.tts_to_file(text="This is a test.") # check speed > 2.0 raises error