Managing fairseq models and docs for API

Eren Gölge 2023-05-23 17:46:16 +02:00
parent 92d4823ad4
commit bb46727733
5 changed files with 117 additions and 20 deletions

View File

@@ -264,6 +264,10 @@ class TTS:
            >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
            >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")

+        Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
+            >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
+            >>> tts.tts_to_file("This is a test.", file_path="output.wav")

        Args:
            model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
            model_path (str, optional): Path to the model checkpoint. Defaults to None.
@@ -342,7 +346,7 @@ class TTS:
    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)
-        if isinstance(model_item["github_rls_url"], list):
+        if "fairseq" in model_name or isinstance(model_item["github_rls_url"], list):
            # return model directory if there are multiple files
            # we assume that the model knows how to load itself
            return None, None, None, None, model_path
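For orientation, a minimal sketch of how a caller might consume the 5-tuple returned above for a fairseq model. The no-argument `TTS()` construction and the names for the two vocoder slots are assumptions; what the diff guarantees for fairseq/multi-file models is only that the final element is the downloaded model directory.

```python
from TTS.api import TTS

tts = TTS()  # assumption: a bare API handle with no model loaded yet
# For fairseq (and other multi-file) models the first four slots come back as
# None and the last one is the directory the model was downloaded to.
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = tts.download_model_by_name(
    "tts_models/eng/fairseq/vits"
)
if model_dir is not None:
    print(f"fairseq model directory: {model_dir}")
```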

View File

@@ -1726,7 +1726,6 @@ class Vits(BaseTTS):
    def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False):
        """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
        Performs some changes for compatibility.

        Args:
@@ -1736,6 +1735,7 @@ class Vits(BaseTTS):
        """
        import json

+        self.disc = None
        # set paths
        config_file = os.path.join(checkpoint_dir, "config.json")
        checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
@@ -1974,7 +1974,7 @@ class FairseqVocab(BaseVocabulary):
    @vocab.setter
    def vocab(self, vocab_file):
-        self._vocab = [x.replace("\n", "") for x in open(vocab_file).readlines()]
+        self._vocab = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
        self.blank = self._vocab[0]
        print(self._vocab)
        self.pad = " "
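For reference, a minimal standalone sketch of driving `load_fairseq_checkpoint`, mirroring how the new `Synthesizer._load_fairseq_from_dir` later in this commit uses it. The checkpoint directory path is illustrative; it must already contain the extracted MMS files the loader expects (`config.json`, `G_100000.pth`, and the vocab file).

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# Illustrative path: an extracted MMS/fairseq release for one language.
checkpoint_dir = "/path/to/extracted/eng"

config = VitsConfig()
model = Vits.init_from_config(config)
model.load_fairseq_checkpoint(config, checkpoint_dir=checkpoint_dir, eval=True)
config = model.config  # the Synthesizer re-reads the config from the loaded model
```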

View File

@@ -1,5 +1,6 @@
import json
import os
+import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
@@ -245,6 +246,12 @@ class ModelManager(object):
        else:
            print(" > Model's license - No license information available")

+    def download_fairseq_model(self, model_name, output_path):
+        URI_PREFIX = "https://dl.fbaipublicfiles.com/mms/tts/"
+        model_type, lang, dataset, model = model_name.split("/")
+        model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
+        self._download_tar_file(model_download_uri, output_path, self.progress_bar)
+
    def download_model(self, model_name):
        """Download model files given the full model name.
        Model name is in the format
@@ -259,11 +266,10 @@ class ModelManager(object):
        Args:
            model_name (str): model name as explained above.
        """
+        model_item = None
        # fetch model info from the dict
        model_type, lang, dataset, model = model_name.split("/")
        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
-        model_item = self.models_dict[model_type][lang][dataset][model]
-        model_item["model_type"] = model_type
        # set the model specific output path
        output_path = os.path.join(self.output_prefix, model_full_name)
        if os.path.exists(output_path):
@@ -271,16 +277,30 @@ class ModelManager(object):
        else:
            os.makedirs(output_path, exist_ok=True)
            print(f" > Downloading model to {output_path}")
-        # download from github release
-        if isinstance(model_item["github_rls_url"], list):
-            self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
+        # download from fairseq
+        if "fairseq" in model_name:
+            self.download_fairseq_model(model_name, output_path)
+            model_item = {
+                "model_type": "tts_models",
+                "license": "CC BY-NC 4.0",
+                "default_vocoder": None,
+                "author": "fairseq",
+                "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
+            }
        else:
-            self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
-        self.print_model_license(model_item=model_item)
+            # get model from models.json
+            model_item = self.models_dict[model_type][lang][dataset][model]
+            model_item["model_type"] = model_type
+            # download from github release
+            if isinstance(model_item["github_rls_url"], list):
+                self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
+            else:
+                self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
+        self.print_model_license(model_item=model_item)
        # find downloaded files
        output_model_path = output_path
        output_config_path = None
-        if model != "tortoise-v2":
+        if model != "tortoise-v2" and "fairseq" not in model_name:
            output_model_path, output_config_path = self._find_files(output_path)
        # update paths in the config.json
        self._update_paths(output_path, output_config_path)
@@ -421,6 +441,39 @@ class ModelManager(object):
            # remove the extracted folder
            rmtree(os.path.join(output_folder, z.namelist()[0]))

+    @staticmethod
+    def _download_tar_file(file_url, output_folder, progress_bar):
+        """Download the github releases"""
+        # download the file
+        r = requests.get(file_url, stream=True)
+        # extract the file
+        try:
+            total_size_in_bytes = int(r.headers.get("content-length", 0))
+            block_size = 1024  # 1 Kibibyte
+            if progress_bar:
+                progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+            temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
+            with open(temp_tar_name, "wb") as file:
+                for data in r.iter_content(block_size):
+                    if progress_bar:
+                        progress_bar.update(len(data))
+                    file.write(data)
+            with tarfile.open(temp_tar_name) as t:
+                t.extractall(output_folder)
+                tar_names = t.getnames()
+            os.remove(temp_tar_name)  # delete tar after extract
+        except tarfile.ReadError:
+            print(f" > Error: Bad tar file - {file_url}")
+            raise tarfile.ReadError  # pylint: disable=raise-missing-from
+        # move the files to the outer path
+        for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
+            src_path = os.path.join(output_folder, tar_names[0], file_path)
+            dst_path = os.path.join(output_folder, os.path.basename(file_path))
+            if src_path != dst_path:
+                copyfile(src_path, dst_path)
+        # remove the extracted folder
+        rmtree(os.path.join(output_folder, tar_names[0]))
+
    @staticmethod
    def _download_model_files(file_urls, output_folder, progress_bar):
        """Download the github releases"""

View File

@@ -7,7 +7,9 @@ import pysbd
import torch

from TTS.config import load_config
+from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
+from TTS.tts.models.vits import Vits

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@@ -98,8 +100,12 @@ class Synthesizer(object):
            self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

        if model_dir:
-            self._load_tts_from_dir(model_dir, use_cuda)
-            self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
+            if "fairseq" in model_dir:
+                self._load_fairseq_from_dir(model_dir, use_cuda)
+                self.output_sample_rate = self.tts_config.audio["sample_rate"]
+            else:
+                self._load_tts_from_dir(model_dir, use_cuda)
+                self.output_sample_rate = self.tts_config.audio["output_sample_rate"]

    @staticmethod
    def _get_segmenter(lang: str):
@@ -133,12 +139,23 @@ class Synthesizer(object):
        if use_cuda:
            self.vc_model.cuda()

+    def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
+        """Load the fairseq model from a directory.
+        We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory.
+        """
+        self.tts_config = VitsConfig()
+        self.tts_model = Vits.init_from_config(self.tts_config)
+        self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
+        self.tts_config = self.tts_model.config
+        if use_cuda:
+            self.tts_model.cuda()
+
    def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
        """Load the TTS model from a directory.
        We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
        """
        config = load_config(os.path.join(model_dir, "config.json"))
        self.tts_config = config
        self.tts_model = setup_tts_model(config)
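A rough sketch of exercising the new branch directly through `Synthesizer`. Treating `model_dir` and `use_cuda` as constructor keyword arguments is an assumption based on the variables referenced in `__init__` above, and the directory path is illustrative; the supported entry point remains `TTS.api` as documented below.

```python
from TTS.utils.synthesizer import Synthesizer

# Assumption: model_dir / use_cuda are accepted as keyword arguments.
# The directory name contains "fairseq", which routes loading through
# _load_fairseq_from_dir above.
synth = Synthesizer(
    model_dir="/path/to/tts_models--eng--fairseq--vits",
    use_cuda=False,
)
wav = synth.tts("This is a test.")
```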

View File

@@ -128,7 +128,7 @@ wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0],
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
```

-Here is an example for a single speaker model.
+#### Here is an example for a single speaker model.

```python
# Init TTS with the target model name
@@ -137,7 +137,7 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
```

-Example voice cloning with YourTTS in English, French and Portuguese:
+#### Example voice cloning with YourTTS in English, French and Portuguese:

```python
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
@@ -146,15 +146,16 @@ tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wa
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```

-Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
+#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`

```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

-Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can
-clone voices by using any model in 🐸TTS.
+#### Example voice cloning by a single speaker TTS model combining with the voice conversion model.
+This way, you can clone voices by using any model in 🐸TTS.

```python
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
@@ -163,8 +164,11 @@ tts.tts_with_vc_to_file(
    speaker_wav="target/speaker.wav",
    file_path="ouptut.wav"
)
```

-Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio.
+#### Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.
+You can use all of your available speakers in the studio.

[🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
@@ -193,4 +197,23 @@ api.emotions
api.list_speakers()
api.list_voices()
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
+```
+
+#### Example text to speech using **Fairseq models in ~1100 languages** 🤯.
+For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
+You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
+
+```python
+from TTS.api import TTS
+api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
+api.tts_to_file("This is a test.", file_path="output.wav")
+
+# TTS with on the fly voice conversion
+api = TTS("tts_models/deu/fairseq/vits")
+api.tts_with_vc_to_file(
+    "Wie sage ich auf Italienisch, dass ich dich liebe?",
+    speaker_wav="target/speaker.wav",
+    file_path="ouptut.wav"
+)
```