Make xtts runnable by version names

This commit is contained in:
Eren G??lge 2023-11-27 14:15:16 +01:00
parent 11ec9f7471
commit 2fd8cf3d94
2 changed files with 44 additions and 2 deletions

View File

@ -80,6 +80,8 @@ class TTS(nn.Module):
self.load_tts_model_by_name(model_name, gpu)
elif "voice_conversion_models" in model_name:
self.load_vc_model_by_name(model_name, gpu)
else:
self.load_model_by_name(model_name, gpu)
if model_path:
self.load_tts_model_by_path(
@ -149,6 +151,15 @@ class TTS(nn.Module):
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path, None
def load_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.load_tts_model_by_name(model_name, gpu)
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the voice conversion models by name.

View File

@ -1,5 +1,6 @@
import json
import os
import re
import tarfile
import zipfile
from pathlib import Path
@ -276,13 +277,15 @@ class ModelManager(object):
model_item["model_url"] = model_item["hf_url"]
elif "fairseq" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
elif "xtts" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
return model_item
def _set_model_item(self, model_name):
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
if "fairseq" in model_name:
model_type = "tts_models"
lang = model_name.split("/")[1]
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
@ -291,10 +294,38 @@ class ModelManager(object):
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
model_item["model_name"] = model_name
elif "xtts" in model_name and len(model_name.split("/")) != 4:
# loading xtts models with only model name (e.g. xtts_v2.0.2)
# check model name has the version number with regex
version_regex = r"v\d+\.\d+\.\d+"
if re.search(version_regex, model_name):
model_version = model_name.split("_")[-1]
else:
model_version = "main"
model_type = "tts_models"
lang = "multilingual"
dataset = "multi-dataset"
model = model_name
model_item = {
"default_vocoder": None,
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": True,
"hf_url": [
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5"
],
}
print(model_item)
else:
# get model from models.json
model_type, lang, dataset, model = model_name.split("/")
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash