mirror of https://github.com/coqui-ai/TTS.git
Fixup
This commit is contained in:
parent
af16348cfb
commit
724360d71a
|
@ -130,7 +130,7 @@ class CS_API:
|
|||
for speaker in self.speakers:
|
||||
if speaker.name == name:
|
||||
return speaker
|
||||
raise ValueError(f"Speaker {name} not found.")
|
||||
raise ValueError(f"Speaker {name} not found in {self.speakers}")
|
||||
|
||||
def id_to_speaker(self, speaker_id):
|
||||
for speaker in self.speakers:
|
||||
|
@ -346,7 +346,7 @@ class TTS:
|
|||
|
||||
def download_model_by_name(self, model_name: str):
|
||||
model_path, config_path, model_item = self.manager.download_model(model_name)
|
||||
if "fairseq" in model_name or isinstance(model_item["github_rls_url"], list):
|
||||
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
|
||||
# return model directory if there are multiple files
|
||||
# we assume that the model knows how to load itself
|
||||
return None, None, None, None, model_path
|
||||
|
|
|
@ -252,6 +252,24 @@ class ModelManager(object):
|
|||
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
|
||||
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
||||
|
||||
def _set_model_item(self, model_name):
|
||||
# fetch model info from the dict
|
||||
model_type, lang, dataset, model = model_name.split("/")
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
if "fairseq" in model_name:
|
||||
model_item = {
|
||||
"model_type": "tts_models",
|
||||
"license": "CC BY-NC 4.0",
|
||||
"default_vocoder": None,
|
||||
"author": "fairseq",
|
||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||
}
|
||||
else:
|
||||
# get model from models.json
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
return model_item, model_full_name, model
|
||||
|
||||
def download_model(self, model_name):
|
||||
"""Download model files given the full model name.
|
||||
Model name is in the format
|
||||
|
@ -266,10 +284,7 @@ class ModelManager(object):
|
|||
Args:
|
||||
model_name (str): model name as explained above.
|
||||
"""
|
||||
model_item = None
|
||||
# fetch model info from the dict
|
||||
model_type, lang, dataset, model = model_name.split("/")
|
||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||
model_item, model_full_name, model = self._set_model_item(model_name)
|
||||
# set the model specific output path
|
||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||
if os.path.exists(output_path):
|
||||
|
@ -280,17 +295,7 @@ class ModelManager(object):
|
|||
# download from fairseq
|
||||
if "fairseq" in model_name:
|
||||
self.download_fairseq_model(model_name, output_path)
|
||||
model_item = {
|
||||
"model_type": "tts_models",
|
||||
"license": "CC BY-NC 4.0",
|
||||
"default_vocoder": None,
|
||||
"author": "fairseq",
|
||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||
}
|
||||
else:
|
||||
# get model from models.json
|
||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||
model_item["model_type"] = model_type
|
||||
# download from github release
|
||||
if isinstance(model_item["github_rls_url"], list):
|
||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||
|
|
|
@ -60,7 +60,7 @@ if is_coqui_available:
|
|||
self.assertIsNone(tts.languages)
|
||||
|
||||
def test_studio_model(self):
|
||||
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio")
|
||||
tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio")
|
||||
tts.tts_to_file(text="This is a test.")
|
||||
|
||||
# check speed > 2.0 raises error
|
||||
|
|
Loading…
Reference in New Issue