From 2fd8cf3d94a4c3dd5567e8f444b541a14f116e50 Mon Sep 17 00:00:00 2001 From: Eren Gölge Date: Mon, 27 Nov 2023 14:15:16 +0100 Subject: [PATCH] Make xtts runnable by version names --- TTS/api.py | 11 +++++++++++ TTS/utils/manage.py | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 3331f30e..c207cb71 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -80,6 +80,8 @@ class TTS(nn.Module): self.load_tts_model_by_name(model_name, gpu) elif "voice_conversion_models" in model_name: self.load_vc_model_by_name(model_name, gpu) + else: + self.load_model_by_name(model_name, gpu) if model_path: self.load_tts_model_by_path( @@ -149,6 +151,15 @@ class TTS(nn.Module): vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) return model_path, config_path, vocoder_path, vocoder_config_path, None + def load_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.load_tts_model_by_name(model_name, gpu) + def load_vc_model_by_name(self, model_name: str, gpu: bool = False): """Load one of the voice conversion models by name. 
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index d3eb8104..bdfc2d95 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -1,5 +1,6 @@ import json import os +import re import tarfile import zipfile from pathlib import Path @@ -276,13 +277,15 @@ class ModelManager(object): model_item["model_url"] = model_item["hf_url"] elif "fairseq" in model_item["model_name"]: model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/" + elif "xtts" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/" return model_item def _set_model_item(self, model_name): # fetch model info from the dict - model_type, lang, dataset, model = model_name.split("/") - model_full_name = f"{model_type}--{lang}--{dataset}--{model}" if "fairseq" in model_name: + model_type = "tts_models" + lang = model_name.split("/")[1] model_item = { "model_type": "tts_models", "license": "CC BY-NC 4.0", @@ -291,10 +294,38 @@ class ModelManager(object): "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", } model_item["model_name"] = model_name + elif "xtts" in model_name and len(model_name.split("/")) != 4: + # loading xtts models with only model name (e.g. 
xtts_v2.0.2) + # check model name has the version number with regex + version_regex = r"v\d+\.\d+\.\d+" + if re.search(version_regex, model_name): + model_version = model_name.split("_")[-1] + else: + model_version = "main" + model_type = "tts_models" + lang = "multilingual" + dataset = "multi-dataset" + model = model_name + model_item = { + "default_vocoder": None, + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": True, + "hf_url": [ + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5" + ], + } + print(model_item) else: # get model from models.json + model_type, lang, dataset, model = model_name.split("/") model_item = self.models_dict[model_type][lang][dataset][model] model_item["model_type"] = model_type + + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" md5hash = model_item["model_hash"] if "model_hash" in model_item else None model_item = self.set_model_url(model_item) return model_item, model_full_name, model, md5hash