Managing fairseq models and docs for API

This commit is contained in:
Eren G??lge 2023-05-23 17:46:16 +02:00
parent 92d4823ad4
commit bb46727733
5 changed files with 117 additions and 20 deletions

View File

@ -264,6 +264,10 @@ class TTS:
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
>>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
>>> tts.tts_to_file("This is a test.", file_path="output.wav")
Args:
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
model_path (str, optional): Path to the model checkpoint. Defaults to None.
@ -342,7 +346,7 @@ class TTS:
def download_model_by_name(self, model_name: str):
model_path, config_path, model_item = self.manager.download_model(model_name)
if isinstance(model_item["github_rls_url"], list):
if "fairseq" in model_name or isinstance(model_item["github_rls_url"], list):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path

View File

@ -1726,7 +1726,6 @@ class Vits(BaseTTS):
def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False):
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
Performs some changes for compatibility.
Args:
@ -1736,6 +1735,7 @@ class Vits(BaseTTS):
"""
import json
self.disc = None
# set paths
config_file = os.path.join(checkpoint_dir, "config.json")
checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
@ -1974,7 +1974,7 @@ class FairseqVocab(BaseVocabulary):
@vocab.setter
def vocab(self, vocab_file):
self._vocab = [x.replace("\n", "") for x in open(vocab_file).readlines()]
self._vocab = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
self.blank = self._vocab[0]
print(self._vocab)
self.pad = " "

View File

@ -1,5 +1,6 @@
import json
import os
import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
@ -245,6 +246,12 @@ class ModelManager(object):
else:
print(" > Model's license - No license information available")
def download_fairseq_model(self, model_name, output_path):
URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/"
model_type, lang, dataset, model = model_name.split("/")
model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
self._download_tar_file(model_download_uri, output_path, self.progress_bar)
def download_model(self, model_name):
"""Download model files given the full model name.
Model name is in the format
@ -259,11 +266,10 @@ class ModelManager(object):
Args:
model_name (str): model name as explained above.
"""
model_item = None
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
# set the model specific output path
output_path = os.path.join(self.output_prefix, model_full_name)
if os.path.exists(output_path):
@ -271,16 +277,30 @@ class ModelManager(object):
else:
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
# download from github release
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
# download from fairseq
if "fairseq" in model_name:
self.download_fairseq_model(model_name, output_path)
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
"default_vocoder": None,
"author": "fairseq",
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item)
# get model from models.json
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type
# download from github release
if isinstance(model_item["github_rls_url"], list):
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
else:
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item)
# find downloaded files
output_model_path = output_path
output_config_path = None
if model != "tortoise-v2":
if model != "tortoise-v2" and "fairseq" not in model_name:
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json
self._update_paths(output_path, output_config_path)
@ -421,6 +441,39 @@ class ModelManager(object):
# remove the extracted folder
rmtree(os.path.join(output_folder, z.namelist()[0]))
@staticmethod
def _download_tar_file(file_url, output_folder, progress_bar):
"""Download the github releases"""
# download the file
r = requests.get(file_url, stream=True)
# extract the file
try:
total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_tar_name, "wb") as file:
for data in r.iter_content(block_size):
if progress_bar:
progress_bar.update(len(data))
file.write(data)
with tarfile.open(temp_tar_name) as t:
t.extractall(output_folder)
tar_names = t.getnames()
os.remove(temp_tar_name) # delete tar after extract
except tarfile.ReadError:
print(f" > Error: Bad tar file - {file_url}")
raise tarfile.ReadError # pylint: disable=raise-missing-from
# move the files to the outer path
for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
src_path = os.path.join(output_folder, tar_names[0], file_path)
dst_path = os.path.join(output_folder, os.path.basename(file_path))
if src_path != dst_path:
copyfile(src_path, dst_path)
# remove the extracted folder
rmtree(os.path.join(output_folder, tar_names[0]))
@staticmethod
def _download_model_files(file_urls, output_folder, progress_bar):
"""Download the github releases"""

View File

@ -7,7 +7,9 @@ import pysbd
import torch
from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@ -98,8 +100,12 @@ class Synthesizer(object):
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
else:
self._load_tts_from_dir(model_dir, use_cuda)
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
@ -133,12 +139,23 @@ class Synthesizer(object):
if use_cuda:
self.vc_model.cuda()
def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the fairseq model from a directory.
We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory.
"""
self.tts_config = VitsConfig()
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config , checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
if use_cuda:
self.tts_model.cuda()
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
self.tts_model = setup_tts_model(config)

View File

@ -128,7 +128,7 @@ wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0],
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
```
Here is an example for a single speaker model.
#### Here is an example for a single speaker model.
```python
# Init TTS with the target model name
@ -137,7 +137,7 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
```
Example voice cloning with YourTTS in English, French and Portuguese:
#### Example voice cloning with YourTTS in English, French and Portuguese:
```python
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
@ -146,15 +146,16 @@ tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wa
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```
Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```
Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can
clone voices by using any model in 🐸TTS.
#### Example voice cloning by a single speaker TTS model combining with the voice conversion model.
This way, you can clone voices by using any model in 🐸TTS.
```python
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
@ -163,8 +164,11 @@ tts.tts_with_vc_to_file(
speaker_wav="target/speaker.wav",
file_path="ouptut.wav"
)
```
Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio.
#### Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.
You can use all of your available speakers in the studio.
[🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
@ -193,4 +197,23 @@ api.emotions
api.list_speakers()
api.list_voices()
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
```
#### Example text to speech using **Fairseq models in ~1100 languages** 🤯.
For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
```python
from TTS.api import TTS
api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
api.tts_to_file("This is a test.", file_path="output.wav")
# TTS with on the fly voice conversion
api = TTS("tts_models/deu/fairseq/vits")
api.tts_with_vc_to_file(
"Wie sage ich auf Italienisch, dass ich dich liebe?",
speaker_wav="target/speaker.wav",
file_path="ouptut.wav"
)
```