mirror of https://github.com/coqui-ai/TTS.git
Make xtts runnable by version names
This commit is contained in:
parent
11ec9f7471
commit
2fd8cf3d94
11
TTS/api.py
11
TTS/api.py
|
@ -80,6 +80,8 @@ class TTS(nn.Module):
|
||||||
self.load_tts_model_by_name(model_name, gpu)
|
self.load_tts_model_by_name(model_name, gpu)
|
||||||
elif "voice_conversion_models" in model_name:
|
elif "voice_conversion_models" in model_name:
|
||||||
self.load_vc_model_by_name(model_name, gpu)
|
self.load_vc_model_by_name(model_name, gpu)
|
||||||
|
else:
|
||||||
|
self.load_model_by_name(model_name, gpu)
|
||||||
|
|
||||||
if model_path:
|
if model_path:
|
||||||
self.load_tts_model_by_path(
|
self.load_tts_model_by_path(
|
||||||
|
@ -149,6 +151,15 @@ class TTS(nn.Module):
|
||||||
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
|
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
|
||||||
return model_path, config_path, vocoder_path, vocoder_config_path, None
|
return model_path, config_path, vocoder_path, vocoder_config_path, None
|
||||||
|
|
||||||
|
def load_model_by_name(self, model_name: str, gpu: bool = False):
|
||||||
|
"""Load one of the 🐸TTS models by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Model name to load. You can list models by ```tts.models```.
|
||||||
|
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
|
||||||
|
"""
|
||||||
|
self.load_tts_model_by_name(model_name, gpu)
|
||||||
|
|
||||||
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
|
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
|
||||||
"""Load one of the voice conversion models by name.
|
"""Load one of the voice conversion models by name.
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tarfile
|
import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -276,13 +277,15 @@ class ModelManager(object):
|
||||||
model_item["model_url"] = model_item["hf_url"]
|
model_item["model_url"] = model_item["hf_url"]
|
||||||
elif "fairseq" in model_item["model_name"]:
|
elif "fairseq" in model_item["model_name"]:
|
||||||
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
|
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
|
||||||
|
elif "xtts" in model_item["model_name"]:
|
||||||
|
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
|
||||||
return model_item
|
return model_item
|
||||||
|
|
||||||
def _set_model_item(self, model_name):
|
def _set_model_item(self, model_name):
|
||||||
# fetch model info from the dict
|
# fetch model info from the dict
|
||||||
model_type, lang, dataset, model = model_name.split("/")
|
|
||||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
|
||||||
if "fairseq" in model_name:
|
if "fairseq" in model_name:
|
||||||
|
model_type = "tts_models"
|
||||||
|
lang = model_name.split("/")[1]
|
||||||
model_item = {
|
model_item = {
|
||||||
"model_type": "tts_models",
|
"model_type": "tts_models",
|
||||||
"license": "CC BY-NC 4.0",
|
"license": "CC BY-NC 4.0",
|
||||||
|
@ -291,10 +294,38 @@ class ModelManager(object):
|
||||||
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||||
}
|
}
|
||||||
model_item["model_name"] = model_name
|
model_item["model_name"] = model_name
|
||||||
|
elif "xtts" in model_name and len(model_name.split("/")) != 4:
|
||||||
|
# loading xtts models with only model name (e.g. xtts_v2.0.2)
|
||||||
|
# check model name has the version number with regex
|
||||||
|
version_regex = r"v\d+\.\d+\.\d+"
|
||||||
|
if re.search(version_regex, model_name):
|
||||||
|
model_version = model_name.split("_")[-1]
|
||||||
|
else:
|
||||||
|
model_version = "main"
|
||||||
|
model_type = "tts_models"
|
||||||
|
lang = "multilingual"
|
||||||
|
dataset = "multi-dataset"
|
||||||
|
model = model_name
|
||||||
|
model_item = {
|
||||||
|
"default_vocoder": None,
|
||||||
|
"license": "CPML",
|
||||||
|
"contact": "info@coqui.ai",
|
||||||
|
"tos_required": True,
|
||||||
|
"hf_url": [
|
||||||
|
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
|
||||||
|
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
|
||||||
|
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
|
||||||
|
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5"
|
||||||
|
],
|
||||||
|
}
|
||||||
|
print(model_item)
|
||||||
else:
|
else:
|
||||||
# get model from models.json
|
# get model from models.json
|
||||||
|
model_type, lang, dataset, model = model_name.split("/")
|
||||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||||
model_item["model_type"] = model_type
|
model_item["model_type"] = model_type
|
||||||
|
|
||||||
|
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||||
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
|
||||||
model_item = self.set_model_url(model_item)
|
model_item = self.set_model_url(model_item)
|
||||||
return model_item, model_full_name, model, md5hash
|
return model_item, model_full_name, model, md5hash
|
||||||
|
|
Loading…
Reference in New Issue