This commit is contained in:
Eren Gölge 2023-06-05 09:39:04 +02:00
parent af16348cfb
commit 724360d71a
3 changed files with 22 additions and 17 deletions

View File

@ -130,7 +130,7 @@ class CS_API:
for speaker in self.speakers:
if speaker.name == name:
return speaker
raise ValueError(f"Speaker {name} not found.")
raise ValueError(f"Speaker {name} not found in {self.speakers}")
def id_to_speaker(self, speaker_id):
for speaker in self.speakers:
@ -346,7 +346,7 @@ class TTS:
def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or isinstance(model_item["github_rls_url"], list):
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path

View File

@ -252,6 +252,24 @@ class ModelManager(object):
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
def _set_model_item(self, model_name):
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
if "fairseq" in model_name:
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
"default_vocoder": None,
"author": "fairseq",
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
else:
# get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
return model_item, model_full_name, model
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
@ -266,10 +284,7 @@ class ModelManager(object):
Args:
model_name (str): model name as explained above.
"""
model_item = None
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
model_item, model_full_name, model = self._set_model_item(model_name)
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
@ -280,17 +295,7 @@ class ModelManager(object):
# download from fairseq
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
"default_vocoder": None,
"author": "fairseq",
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
else:
# get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
# download from github release
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)

View File

@ -60,7 +60,7 @@ if is_coqui_available:
self.assertIsNone(tts.languages)
def test_studio_model(self):
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio")
tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio")
tts.tts_to_file(text="This is a test.")
# check speed > 2.0 raises error