Port Fairseq TTS models (#2628)

* Load fairseq models

* Add docs and missing files

* Managing fairseq models and docs for API

* Make style

* Use scarf URL

* Add tests

* Fix URL

* Pass cpu

* Make lint

* Fixup

* Make lint

* fixup

* Fixup

* Change tokenization order

* Update README

* Fixup

* Fixup
Eren Gölge 2023-06-05 11:15:13 +02:00 committed by GitHub
parent 0d5e68a09f
commit e785d101a1
10 changed files with 314 additions and 38 deletions


@@ -1,10 +1,13 @@
## 🐸Coqui.ai News
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
- 📣 🐸TTS now supports 🐢Tortoise with faster inference.
- 📣 **Coqui Studio API** has landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
- 📣 The [**Coqui Studio API**](https://docs.coqui.ai/docs) is live.
- 📣 Voice generation with prompts - **Prompt to Voice** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin)!! - [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
- 📣 Voice generation with fusion - **Voice fusion** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).
- 📣 Voice cloning is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).

## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
@@ -185,7 +188,9 @@ from TTS.api import TTS
model_name = TTS.list_models()[0]
# Init TTS
tts = TTS(model_name)
# Run TTS
# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
# Text to speech with a numpy output
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
@@ -199,7 +204,8 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
# Run TTS
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)

# Example voice cloning with YourTTS in English, French and Portuguese
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
@@ -221,7 +227,9 @@ tts.tts_with_vc_to_file(
    file_path="ouptut.wav"
)

# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.
# You can use all of your available speakers in the studio.
# [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
# You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
@@ -234,6 +242,20 @@ tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_b
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
# Run TTS with emotion and speed control
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5)

# Example text to speech using **Fairseq models in ~1100 languages** 🤯.
# For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
# You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).

# TTS with on the fly voice conversion
api = TTS("tts_models/deu/fairseq/vits")
api.tts_with_vc_to_file(
    "Wie sage ich auf Italienisch, dass ich dich liebe?",
    speaker_wav="target/speaker.wav",
    file_path="ouptut.wav"
)
```

### Command line `tts`


@@ -130,7 +130,7 @@ class CS_API:
        for speaker in self.speakers:
            if speaker.name == name:
                return speaker
        raise ValueError(f"Speaker {name} not found in {self.speakers}")

    def id_to_speaker(self, speaker_id):
        for speaker in self.speakers:
@@ -264,6 +264,10 @@ class TTS:
        >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
        >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")

        Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):

        >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
        >>> tts.tts_to_file("This is a test.", file_path="output.wav")

        Args:
            model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
            model_path (str, optional): Path to the model checkpoint. Defaults to None.
@@ -342,7 +346,7 @@ class TTS:
    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)
        if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
            # return model directory if there are multiple files
            # we assume that the model knows how to load itself
            return None, None, None, None, model_path


@@ -25,11 +25,12 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
@@ -1723,6 +1724,50 @@ class Vits(BaseTTS):
            self.eval()
            assert not self.training

    def load_fairseq_checkpoint(
        self, config, checkpoint_dir, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
        Performs some changes for compatibility.

        Args:
            config (Coqpit): 🐸TTS model config.
            checkpoint_dir (str): Path to the checkpoint directory.
            eval (bool, optional): Set to True for evaluation. Defaults to False.
        """
        import json

        from TTS.tts.utils.text.cleaners import basic_cleaners

        self.disc = None
        # set paths
        config_file = os.path.join(checkpoint_dir, "config.json")
        checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
        vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
        # set config params
        with open(config_file, "r", encoding="utf-8") as file:
            # Load the JSON data as a dictionary
            config_org = json.load(file)
        self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
        # self.config.add_blank = config['add_blank']
        # set tokenizer
        vocab = FairseqVocab(vocab_file)
        self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
        self.tokenizer = TTSTokenizer(
            use_phonemes=False,
            text_cleaner=basic_cleaners,
            characters=vocab,
            phonemizer=None,
            add_blank=config_org["data"]["add_blank"],
            use_eos_bos=False,
        )
        # load fairseq checkpoint
        new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
        self.load_state_dict(new_chk)
        if eval:
            self.eval()
            assert not self.training
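
For reference, a minimal loading sketch in the spirit of the new `Synthesizer._load_fairseq_from_dir` added later in this commit; the checkpoint directory path is a placeholder and is expected to contain the MMS files `G_100000.pth`, `config.json` and `vocab.txt`:

```python
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

# Placeholder path to an extracted fairseq/MMS VITS release
checkpoint_dir = "/path/to/tts_models--deu--fairseq--vits"

config = VitsConfig()
model = Vits.init_from_config(config)
model.load_fairseq_checkpoint(config, checkpoint_dir=checkpoint_dir, eval=True)
```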
    @staticmethod
    def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
        """Initiate model from config

@@ -1919,3 +1964,24 @@ class VitsCharacters(BaseCharacters):
            is_unique=False,
            is_sorted=True,
        )
class FairseqVocab(BaseVocabulary):
    def __init__(self, vocab: str):
        super(FairseqVocab).__init__()
        self.vocab = vocab

    @property
    def vocab(self):
        """Return the vocabulary dictionary."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab_file):
        with open(vocab_file, encoding="utf-8") as f:
            self._vocab = [x.replace("\n", "") for x in f.readlines()]
        self.blank = self._vocab[0]
        print(self._vocab)
        self.pad = " "
        self._char_to_id = {s: i for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}  # pylint: disable=unnecessary-comprehension
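
A small sketch of how `FairseqVocab` reads an MMS `vocab.txt` (one symbol per line, with the first symbol used as the blank token, as in the setter above); the file path is a placeholder:

```python
from TTS.tts.models.vits import FairseqVocab

# Placeholder path to the vocab.txt shipped with a fairseq/MMS VITS release
vocab = FairseqVocab("/path/to/tts_models--eng--fairseq--vits/vocab.txt")
print(vocab.num_chars)    # number of symbols read from the file
print(repr(vocab.blank))  # first line of the file becomes the blank token
```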

TTS/tts/utils/fairseq.py (new file, 48 additions)

@@ -0,0 +1,48 @@
import torch


def rehash_fairseq_vits_checkpoint(checkpoint_file):
    """Map layer names in a fairseq VITS (MMS) checkpoint to the 🐸TTS Vits layer names."""
    chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
    new_chk = {}
    for k, v in chk.items():
        if "enc_p." in k:
            new_chk[k.replace("enc_p.", "text_encoder.")] = v
        elif "dec." in k:
            new_chk[k.replace("dec.", "waveform_decoder.")] = v
        elif "enc_q." in k:
            new_chk[k.replace("enc_q.", "posterior_encoder.")] = v
        elif "flow.flows.2." in k:
            new_chk[k.replace("flow.flows.2.", "flow.flows.1.")] = v
        elif "flow.flows.4." in k:
            new_chk[k.replace("flow.flows.4.", "flow.flows.2.")] = v
        elif "flow.flows.6." in k:
            new_chk[k.replace("flow.flows.6.", "flow.flows.3.")] = v
        elif "dp.flows.0.m" in k:
            new_chk[k.replace("dp.flows.0.m", "duration_predictor.flows.0.translation")] = v
        elif "dp.flows.0.logs" in k:
            new_chk[k.replace("dp.flows.0.logs", "duration_predictor.flows.0.log_scale")] = v
        elif "dp.flows.1" in k:
            new_chk[k.replace("dp.flows.1", "duration_predictor.flows.1")] = v
        elif "dp.flows.3" in k:
            new_chk[k.replace("dp.flows.3", "duration_predictor.flows.2")] = v
        elif "dp.flows.5" in k:
            new_chk[k.replace("dp.flows.5", "duration_predictor.flows.3")] = v
        elif "dp.flows.7" in k:
            new_chk[k.replace("dp.flows.7", "duration_predictor.flows.4")] = v
        elif "dp.post_flows.0.m" in k:
            new_chk[k.replace("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation")] = v
        elif "dp.post_flows.0.logs" in k:
            new_chk[k.replace("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale")] = v
        elif "dp.post_flows.1" in k:
            new_chk[k.replace("dp.post_flows.1", "duration_predictor.post_flows.1")] = v
        elif "dp.post_flows.3" in k:
            new_chk[k.replace("dp.post_flows.3", "duration_predictor.post_flows.2")] = v
        elif "dp.post_flows.5" in k:
            new_chk[k.replace("dp.post_flows.5", "duration_predictor.post_flows.3")] = v
        elif "dp.post_flows.7" in k:
            new_chk[k.replace("dp.post_flows.7", "duration_predictor.post_flows.4")] = v
        elif "dp." in k:
            new_chk[k.replace("dp.", "duration_predictor.")] = v
        else:
            new_chk[k] = v
    return new_chk
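
A hedged usage sketch of this helper (the checkpoint path is a placeholder; `Vits.load_fairseq_checkpoint` above wraps exactly this step):

```python
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint

# Placeholder path to a fairseq/MMS VITS generator checkpoint
new_state = rehash_fairseq_vits_checkpoint("/path/to/G_100000.pth")
# Keys are now named after 🐸TTS Vits modules, e.g. "text_encoder.*", "waveform_decoder.*",
# "posterior_encoder.*", "duration_predictor.*", and can be passed to Vits.load_state_dict().
print(sorted({k.split(".")[0] for k in new_state}))
```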


@@ -63,6 +63,18 @@ class BaseVocabulary:
        the vocabulary."""
        return self.char_to_id(self.blank) if self.blank else len(self.vocab)

    @property
    def bos_id(self) -> int:
        """Return the index of the bos character. If the bos character is not specified, return the length of the
        vocabulary."""
        return self.char_to_id(self.bos) if self.bos else len(self.vocab)

    @property
    def eos_id(self) -> int:
        """Return the index of the eos character. If the eos character is not specified, return the length of the
        vocabulary."""
        return self.char_to_id(self.eos) if self.eos else len(self.vocab)

    @property
    def vocab(self):
        """Return the vocabulary dictionary."""
@@ -71,11 +83,13 @@ class BaseVocabulary:
    @vocab.setter
    def vocab(self, vocab):
        """Set the vocabulary dictionary and character mapping dictionaries."""
        self._vocab, self._char_to_id, self._id_to_char = None, None, None
        if vocab is not None:
            self._vocab = vocab
            self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
            self._id_to_char = {
                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
            }

    @staticmethod
    def init_from_config(config, **kwargs):
@@ -93,6 +107,17 @@ class BaseVocabulary:
            )
        return BaseVocabulary(**kwargs), config

    def to_config(self) -> "CharactersConfig":
        return CharactersConfig(
            vocab_dict=self._vocab,
            pad=self.pad,
            eos=self.eos,
            bos=self.bos,
            blank=self.blank,
            is_unique=False,
            is_sorted=False,
        )

    @property
    def num_chars(self):
        """Return number of tokens in the vocabulary."""
@@ -174,6 +199,14 @@ class BaseCharacters:
    def blank_id(self) -> int:
        return self.char_to_id(self.blank) if self.blank else len(self.vocab)

    @property
    def eos_id(self) -> int:
        return self.char_to_id(self.eos) if self.eos else len(self.vocab)

    @property
    def bos_id(self) -> int:
        return self.char_to_id(self.bos) if self.bos else len(self.vocab)

    @property
    def characters(self):
        return self._characters


@@ -108,11 +108,12 @@ class TTSTokenizer:
            text = self.text_cleaner(text)
        if self.use_phonemes:
            text = self.phonemizer.phonemize(text, separator="", language=language)
        text = self.encode(text)
        if self.add_blank:
            text = self.intersperse_blank_char(text, True)
        if self.use_eos_bos:
            text = self.pad_with_bos_eos(text)
        return text
    def ids_to_text(self, id_sequence: List[int]) -> str:
        """Converts a sequence of token IDs to a string of text."""

@@ -120,14 +121,14 @@ class TTSTokenizer:
    def pad_with_bos_eos(self, char_sequence: List[str]):
        """Pads a sequence with the special BOS and EOS characters."""
        return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]

    def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
        """Intersperses the blank character between characters in a sequence.

        Use the ```blank``` character if defined else use the ```pad``` character.
        """
        char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad
        result = [char_to_use] * (len(char_sequence) * 2 + 1)
        result[1::2] = char_sequence
        return result
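
To make the reordering concrete: `text_to_ids` now encodes characters to IDs first, and only then intersperses the blank ID and pads with the BOS/EOS IDs. A self-contained toy illustration of that pipeline (the vocabulary below is made up, not the library's character set):

```python
# Toy vocabulary: blank first, then characters, then BOS/EOS markers
vocab = ["<BLNK>", "a", "b", "<BOS>", "<EOS>"]
char_to_id = {c: i for i, c in enumerate(vocab)}

ids = [char_to_id[c] for c in "ab"]             # encode first -> [1, 2]
blank_id = char_to_id["<BLNK>"]

# intersperse the blank *ID* (mirrors intersperse_blank_char on an ID sequence)
interspersed = [blank_id] * (len(ids) * 2 + 1)
interspersed[1::2] = ids                        # -> [0, 1, 0, 2, 0]

# pad with the BOS/EOS *IDs* (mirrors pad_with_bos_eos)
padded = [char_to_id["<BOS>"]] + interspersed + [char_to_id["<EOS>"]]
print(padded)                                   # [3, 0, 1, 0, 2, 0, 4]
```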


@@ -1,5 +1,6 @@
import json
import os
import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
@@ -245,6 +246,30 @@ class ModelManager(object):
        else:
            print(" > Model's license - No license information available")

    def download_fairseq_model(self, model_name, output_path):
        URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
        _, lang, _, _ = model_name.split("/")
        model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz")
        self._download_tar_file(model_download_uri, output_path, self.progress_bar)

    def _set_model_item(self, model_name):
        # fetch model info from the dict
        model_type, lang, dataset, model = model_name.split("/")
        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
        if "fairseq" in model_name:
            model_item = {
                "model_type": "tts_models",
                "license": "CC BY-NC 4.0",
                "default_vocoder": None,
                "author": "fairseq",
                "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
            }
        else:
            # get model from models.json
            model_item = self.models_dict[model_type][lang][dataset][model]
            model_item["model_type"] = model_type
        return model_item, model_full_name, model
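
A tiny illustration of the name-to-URL mapping that `download_fairseq_model` performs (the model name here is just an example):

```python
import os

model_name = "tts_models/deu/fairseq/vits"
_, lang, _, _ = model_name.split("/")  # lang == "deu"
uri = os.path.join("https://coqui.gateway.scarf.sh/fairseq/", f"{lang}.tar.gz")
print(uri)  # https://coqui.gateway.scarf.sh/fairseq/deu.tar.gz
```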
    def download_model(self, model_name):
        """Download model files given the full model name.
        Model name is in the format
@@ -259,11 +284,7 @@
        Args:
            model_name (str): model name as explained above.
        """
        model_item, model_full_name, model = self._set_model_item(model_name)
        # set the model specific output path
        output_path = os.path.join(self.output_prefix, model_full_name)
        if os.path.exists(output_path):
@@ -271,16 +292,20 @@ class ModelManager(object):
        else:
            os.makedirs(output_path, exist_ok=True)
            print(f" > Downloading model to {output_path}")
            # download from fairseq
            if "fairseq" in model_name:
                self.download_fairseq_model(model_name, output_path)
            else:
                # download from github release
                if isinstance(model_item["github_rls_url"], list):
                    self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
                else:
                    self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
                self.print_model_license(model_item=model_item)
        # find downloaded files
        output_model_path = output_path
        output_config_path = None
        if model != "tortoise-v2" and "fairseq" not in model_name:
            output_model_path, output_config_path = self._find_files(output_path)
        # update paths in the config.json
        self._update_paths(output_path, output_config_path)
@@ -421,6 +446,39 @@ class ModelManager(object):
        # remove the extracted folder
        rmtree(os.path.join(output_folder, z.namelist()[0]))

    @staticmethod
    def _download_tar_file(file_url, output_folder, progress_bar):
        """Download the github releases"""
        # download the file
        r = requests.get(file_url, stream=True)
        # extract the file
        try:
            total_size_in_bytes = int(r.headers.get("content-length", 0))
            block_size = 1024  # 1 Kibibyte
            if progress_bar:
                progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
            temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
            with open(temp_tar_name, "wb") as file:
                for data in r.iter_content(block_size):
                    if progress_bar:
                        progress_bar.update(len(data))
                    file.write(data)
            with tarfile.open(temp_tar_name) as t:
                t.extractall(output_folder)
                tar_names = t.getnames()
            os.remove(temp_tar_name)  # delete tar after extract
        except tarfile.ReadError:
            print(f" > Error: Bad tar file - {file_url}")
            raise tarfile.ReadError  # pylint: disable=raise-missing-from
        # move the files to the outer path
        for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
            src_path = os.path.join(output_folder, tar_names[0], file_path)
            dst_path = os.path.join(output_folder, os.path.basename(file_path))
            if src_path != dst_path:
                copyfile(src_path, dst_path)
        # remove the extracted folder
        rmtree(os.path.join(output_folder, tar_names[0]))

    @staticmethod
    def _download_model_files(file_urls, output_folder, progress_bar):
        """Download the github releases"""


@@ -7,7 +7,9 @@ import pysbd
import torch

from TTS.config import load_config
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.models.vits import Vits

# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
@@ -98,8 +100,12 @@ class Synthesizer(object):
            self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

        if model_dir:
            if "fairseq" in model_dir:
                self._load_fairseq_from_dir(model_dir, use_cuda)
                self.output_sample_rate = self.tts_config.audio["sample_rate"]
            else:
                self._load_tts_from_dir(model_dir, use_cuda)
                self.output_sample_rate = self.tts_config.audio["output_sample_rate"]

    @staticmethod
    def _get_segmenter(lang: str):
@@ -133,12 +139,23 @@ class Synthesizer(object):
        if use_cuda:
            self.vc_model.cuda()

    def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
        """Load the fairseq model from a directory.

        We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory.
        """
        self.tts_config = VitsConfig()
        self.tts_model = Vits.init_from_config(self.tts_config)
        self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
        self.tts_config = self.tts_model.config
        if use_cuda:
            self.tts_model.cuda()

    def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
        """Load the TTS model from a directory.

        We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
        """
        config = load_config(os.path.join(model_dir, "config.json"))
        self.tts_config = config
        self.tts_model = setup_tts_model(config)


@@ -128,7 +128,7 @@ wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0],
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
```

#### Here is an example for a single speaker model.

```python
# Init TTS with the target model name
@@ -137,7 +137,7 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
```

#### Example voice cloning with YourTTS in English, French and Portuguese:

```python
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
@@ -146,15 +146,16 @@ tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wa
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
```

#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`

```python
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```

#### Example voice cloning by a single speaker TTS model combining with the voice conversion model.

This way, you can clone voices by using any model in 🐸TTS.

```python
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
@@ -163,8 +164,11 @@ tts.tts_with_vc_to_file(
    speaker_wav="target/speaker.wav",
    file_path="ouptut.wav"
)
```

#### Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.

You can use all of your available speakers in the studio.
[🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.

@@ -194,3 +198,22 @@ api.list_speakers()
api.list_voices()
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
```
#### Example text to speech using **Fairseq models in ~1100 languages** 🤯.
For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
```python
from TTS.api import TTS
api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
api.tts_to_file("This is a test.", file_path="output.wav")
# TTS with on the fly voice conversion
api = TTS("tts_models/deu/fairseq/vits")
api.tts_with_vc_to_file(
"Wie sage ich auf Italienisch, dass ich dich liebe?",
speaker_wav="target/speaker.wav",
file_path="ouptut.wav"
)
```


@@ -60,7 +60,7 @@ if is_coqui_available:
            self.assertIsNone(tts.languages)

        def test_studio_model(self):
            tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio")
            tts.tts_to_file(text="This is a test.")

            # check speed > 2.0 raises error

@@ -83,6 +83,10 @@ if is_coqui_available:
            wav = tts.tts(text="This is a test.", speed=2.0, emotion="Sad")
            self.assertGreater(len(wav), 0)

        def test_fairseq_model(self):  # pylint: disable=no-self-use
            tts = TTS(model_name="tts_models/eng/fairseq/vits")
            tts.tts_to_file(text="This is a test.")

        def test_multi_speaker_multi_lingual_model(self):
            tts = TTS()
            tts.load_tts_model_by_name(tts.models[0])  # YourTTS