mirror of https://github.com/coqui-ai/TTS.git
Port Fairseq TTS models (#2628)
* Load fairseq models * Add docs and missing files * Managing fairseq models and docs for API * Make style * Use scarf URL * Add tests * Fix URL * Pass cpu * Make lint * Fixup * Make lint * fixup * Fixup * Change tokenization order * Update README * Fixup * Fixup
This commit is contained in:
parent
0d5e68a09f
commit
e785d101a1
34
README.md
34
README.md
|
@ -1,10 +1,13 @@
|
||||||
|
|
||||||
|
|
||||||
## 🐸Coqui.ai News
|
## 🐸Coqui.ai News
|
||||||
- 📣 Coqui Studio API is landed on 🐸TTS. You can use the studio voices in combination with 🐸TTS models. [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
|
- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
|
||||||
- 📣 Voice generation with prompts - **Prompt to Voice** - is live on Coqui.ai!! [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
|
- 📣 🐸TTS now supports 🐢Tortoise with faster inference.
|
||||||
- 📣 Clone your voice with a single click on [🐸Coqui.ai](https://app.coqui.ai/auth/signin)
|
- 📣 **Coqui Studio API** has landed on 🐸TTS. - [Example](https://github.com/coqui-ai/TTS/blob/dev/README.md#-python-api)
|
||||||
<br>
|
- 📣 [**Coqui Studio API**](https://docs.coqui.ai/docs) is live.
|
||||||
|
- 📣 Voice generation with prompts - **Prompt to Voice** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin)!! - [Blog Post](https://coqui.ai/blog/tts/prompt-to-voice)
|
||||||
|
- 📣 Voice generation with fusion - **Voice fusion** - is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).
|
||||||
|
- 📣 Voice cloning is live on [**Coqui Studio**](https://app.coqui.ai/auth/signin).
|
||||||
|
|
||||||
## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
|
## <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/coqui-log-green-TTS.png" height="56"/>
|
||||||
|
|
||||||
|
@ -185,7 +188,9 @@ from TTS.api import TTS
|
||||||
model_name = TTS.list_models()[0]
|
model_name = TTS.list_models()[0]
|
||||||
# Init TTS
|
# Init TTS
|
||||||
tts = TTS(model_name)
|
tts = TTS(model_name)
|
||||||
|
|
||||||
# Run TTS
|
# Run TTS
|
||||||
|
|
||||||
# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
|
# ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
|
||||||
# Text to speech with a numpy output
|
# Text to speech with a numpy output
|
||||||
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
|
wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
|
||||||
|
@ -199,7 +204,8 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
|
||||||
# Run TTS
|
# Run TTS
|
||||||
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
|
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
|
||||||
|
|
||||||
# Example voice cloning with YourTTS in English, French and Portuguese:
|
# Example voice cloning with YourTTS in English, French and Portuguese
|
||||||
|
|
||||||
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
|
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
|
||||||
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
|
tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
|
||||||
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
|
tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr-fr", file_path="output.wav")
|
||||||
|
@ -221,7 +227,9 @@ tts.tts_with_vc_to_file(
|
||||||
file_path="ouptut.wav"
|
file_path="ouptut.wav"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio.
|
# Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.
|
||||||
|
|
||||||
|
# You can use all of your available speakers in the studio.
|
||||||
# [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
|
# [🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
|
||||||
# You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
|
# You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
|
||||||
|
|
||||||
|
@ -234,6 +242,20 @@ tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio", progress_b
|
||||||
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
|
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH)
|
||||||
# Run TTS with emotion and speed control
|
# Run TTS with emotion and speed control
|
||||||
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5)
|
tts.tts_to_file(text="This is a test.", file_path=OUTPUT_PATH, emotion="Happy", speed=1.5)
|
||||||
|
|
||||||
|
|
||||||
|
# Example text to speech using **Fairseq models in ~1100 languages** 🤯.
|
||||||
|
|
||||||
|
# For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
|
||||||
|
# You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
|
||||||
|
|
||||||
|
# TTS with on the fly voice conversion
|
||||||
|
api = TTS("tts_models/deu/fairseq/vits")
|
||||||
|
api.tts_with_vc_to_file(
|
||||||
|
"Wie sage ich auf Italienisch, dass ich dich liebe?",
|
||||||
|
speaker_wav="target/speaker.wav",
|
||||||
|
file_path="ouptut.wav"
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
### Command line `tts`
|
### Command line `tts`
|
||||||
|
|
|
@ -130,7 +130,7 @@ class CS_API:
|
||||||
for speaker in self.speakers:
|
for speaker in self.speakers:
|
||||||
if speaker.name == name:
|
if speaker.name == name:
|
||||||
return speaker
|
return speaker
|
||||||
raise ValueError(f"Speaker {name} not found.")
|
raise ValueError(f"Speaker {name} not found in {self.speakers}")
|
||||||
|
|
||||||
def id_to_speaker(self, speaker_id):
|
def id_to_speaker(self, speaker_id):
|
||||||
for speaker in self.speakers:
|
for speaker in self.speakers:
|
||||||
|
@ -264,6 +264,10 @@ class TTS:
|
||||||
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
|
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
|
||||||
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
|
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
|
||||||
|
|
||||||
|
Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
|
||||||
|
>>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
|
||||||
|
>>> tts.tts_to_file("This is a test.", file_path="output.wav")
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
|
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None.
|
||||||
model_path (str, optional): Path to the model checkpoint. Defaults to None.
|
model_path (str, optional): Path to the model checkpoint. Defaults to None.
|
||||||
|
@ -342,7 +346,7 @@ class TTS:
|
||||||
|
|
||||||
def download_model_by_name(self, model_name: str):
|
def download_model_by_name(self, model_name: str):
|
||||||
model_path, config_path, model_item = self.manager.download_model(model_name)
|
model_path, config_path, model_item = self.manager.download_model(model_name)
|
||||||
if isinstance(model_item["github_rls_url"], list):
|
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["github_rls_url"], list)):
|
||||||
# return model directory if there are multiple files
|
# return model directory if there are multiple files
|
||||||
# we assume that the model knows how to load itself
|
# we assume that the model knows how to load itself
|
||||||
return None, None, None, None, model_path
|
return None, None, None, None, model_path
|
||||||
|
|
|
@ -25,11 +25,12 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||||
from TTS.tts.models.base_tts import BaseTTS
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
|
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
|
||||||
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
|
||||||
from TTS.tts.utils.languages import LanguageManager
|
from TTS.tts.utils.languages import LanguageManager
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
|
||||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||||
from TTS.tts.utils.visual import plot_alignment
|
from TTS.tts.utils.visual import plot_alignment
|
||||||
from TTS.utils.io import load_fsspec
|
from TTS.utils.io import load_fsspec
|
||||||
|
@ -1723,6 +1724,50 @@ class Vits(BaseTTS):
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
||||||
|
def load_fairseq_checkpoint(
    self, config, checkpoint_dir, eval=False
):  # pylint: disable=unused-argument, redefined-builtin
    """Load a VITS checkpoint released by fairseq.

    See https://github.com/facebookresearch/fairseq/tree/main/examples/mms

    Three files are read from ``checkpoint_dir``: ``config.json`` (the fairseq
    config), ``vocab.txt`` (the token list) and ``G_100000.pth`` (the weights).
    The sample rate is copied from the fairseq config, the text-encoder
    embedding and the tokenizer are rebuilt from the vocabulary file, and the
    renamed weights are loaded into this model.

    Args:
        config (Coqpit): 🐸TTS model config.
        checkpoint_dir (str): Path to the checkpoint directory.
        eval (bool, optional): Set to True for evaluation. Defaults to False.
    """
    import json

    from TTS.tts.utils.text.cleaners import basic_cleaners

    def in_checkpoint_dir(file_name):
        return os.path.join(checkpoint_dir, file_name)

    self.disc = None

    # pull the audio settings out of the original fairseq config
    with open(in_checkpoint_dir("config.json"), "r", encoding="utf-8") as handle:
        fairseq_config = json.load(handle)
    data_config = fairseq_config["data"]
    self.config.audio.sample_rate = data_config["sampling_rate"]

    # the embedding table is re-created so its size matches the fairseq vocabulary
    fairseq_vocab = FairseqVocab(in_checkpoint_dir("vocab.txt"))
    self.text_encoder.emb = nn.Embedding(fairseq_vocab.num_chars, config.model_args.hidden_channels)
    self.tokenizer = TTSTokenizer(
        use_phonemes=False,
        text_cleaner=basic_cleaners,
        characters=fairseq_vocab,
        phonemizer=None,
        add_blank=data_config["add_blank"],
        use_eos_bos=False,
    )

    # weights are loaded under their renamed (🐸TTS) layer names
    state_dict = rehash_fairseq_vits_checkpoint(in_checkpoint_dir("G_100000.pth"))
    self.load_state_dict(state_dict)

    if eval:
        self.eval()
        assert not self.training
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||||
"""Initiate model from config
|
"""Initiate model from config
|
||||||
|
@ -1919,3 +1964,24 @@ class VitsCharacters(BaseCharacters):
|
||||||
is_unique=False,
|
is_unique=False,
|
||||||
is_sorted=True,
|
is_sorted=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FairseqVocab(BaseVocabulary):
    """Vocabulary read from a Fairseq ``vocab.txt`` file.

    The file is read as one token per line; a token's line index is its ID.
    The first token is used as the blank token and ``" "`` as the pad token.
    """

    def __init__(self, vocab: str):  # pylint: disable=super-init-not-called
        """
        Args:
            vocab (str): Path to the vocabulary file.
        """
        # The previous `super(FairseqVocab).__init__()` only created an unbound
        # super proxy; the base initializer never ran, so `bos` / `eos` were left
        # undefined and `bos_id`, `eos_id` and `to_config` raised AttributeError.
        # They are set explicitly here instead.
        # NOTE(review): assumes `bos` / `eos` are plain attributes on
        # BaseVocabulary, like `pad` / `blank` assigned below - confirm.
        self.bos = None
        self.eos = None
        self.vocab = vocab

    @property
    def vocab(self):
        """Return the vocabulary as a list of tokens ordered by token ID."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab_file):
        """Read the tokens from ``vocab_file`` and rebuild the lookup tables."""
        with open(vocab_file, encoding="utf-8") as f:
            self._vocab = [line.rstrip("\n") for line in f]
        self.blank = self._vocab[0]
        self.pad = " "
        self._char_to_id = {token: idx for idx, token in enumerate(self._vocab)}
        self._id_to_char = dict(enumerate(self._vocab))
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def rehash_fairseq_vits_checkpoint(checkpoint_file):
    """Load a fairseq VITS checkpoint on CPU and rename its weight keys.

    The ``"model"`` entry of the checkpoint is read and every key is passed
    through an ordered rename table. The first table entry whose fairseq
    fragment occurs in the key is applied (all occurrences are replaced);
    keys matching no entry are copied unchanged.

    Args:
        checkpoint_file (str): Path to the fairseq checkpoint file.

    Returns:
        dict: Weights stored under the renamed keys, in the original order.
    """
    # Order matters: specific fragments must come before the generic "dp." one.
    rename_table = (
        ("enc_p.", "text_encoder."),
        ("dec.", "waveform_decoder."),
        ("enc_q.", "posterior_encoder."),
        ("flow.flows.2.", "flow.flows.1."),
        ("flow.flows.4.", "flow.flows.2."),
        ("flow.flows.6.", "flow.flows.3."),
        ("dp.flows.0.m", "duration_predictor.flows.0.translation"),
        ("dp.flows.0.logs", "duration_predictor.flows.0.log_scale"),
        ("dp.flows.1", "duration_predictor.flows.1"),
        ("dp.flows.3", "duration_predictor.flows.2"),
        ("dp.flows.5", "duration_predictor.flows.3"),
        ("dp.flows.7", "duration_predictor.flows.4"),
        ("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation"),
        ("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale"),
        ("dp.post_flows.1", "duration_predictor.post_flows.1"),
        ("dp.post_flows.3", "duration_predictor.post_flows.2"),
        ("dp.post_flows.5", "duration_predictor.post_flows.3"),
        ("dp.post_flows.7", "duration_predictor.post_flows.4"),
        ("dp.", "duration_predictor."),
    )

    fairseq_weights = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]

    renamed = {}
    for key, weight in fairseq_weights.items():
        for fairseq_fragment, tts_fragment in rename_table:
            if fairseq_fragment in key:
                renamed[key.replace(fairseq_fragment, tts_fragment)] = weight
                break
        else:
            renamed[key] = weight
    return renamed
|
|
@ -63,6 +63,18 @@ class BaseVocabulary:
|
||||||
the vocabulary."""
|
the vocabulary."""
|
||||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||||
|
|
||||||
|
@property
def bos_id(self) -> int:
    """ID of the bos character.

    When no bos character is set, the length of the vocabulary is returned
    instead.
    """
    if not self.bos:
        return len(self.vocab)
    return self.char_to_id(self.bos)
|
||||||
|
|
||||||
|
@property
def eos_id(self) -> int:
    """ID of the eos character.

    When no eos character is set, the length of the vocabulary is returned
    instead.
    """
    if not self.eos:
        return len(self.vocab)
    return self.char_to_id(self.eos)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vocab(self):
|
def vocab(self):
|
||||||
"""Return the vocabulary dictionary."""
|
"""Return the vocabulary dictionary."""
|
||||||
|
@ -71,11 +83,13 @@ class BaseVocabulary:
|
||||||
@vocab.setter
|
@vocab.setter
|
||||||
def vocab(self, vocab):
|
def vocab(self, vocab):
|
||||||
"""Set the vocabulary dictionary and character mapping dictionaries."""
|
"""Set the vocabulary dictionary and character mapping dictionaries."""
|
||||||
self._vocab = vocab
|
self._vocab, self._char_to_id, self._id_to_char = None, None, None
|
||||||
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
if vocab is not None:
|
||||||
self._id_to_char = {
|
self._vocab = vocab
|
||||||
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
|
self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
|
||||||
}
|
self._id_to_char = {
|
||||||
|
idx: char for idx, char in enumerate(self._vocab) # pylint: disable=unnecessary-comprehension
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_from_config(config, **kwargs):
|
def init_from_config(config, **kwargs):
|
||||||
|
@ -93,6 +107,17 @@ class BaseVocabulary:
|
||||||
)
|
)
|
||||||
return BaseVocabulary(**kwargs), config
|
return BaseVocabulary(**kwargs), config
|
||||||
|
|
||||||
|
def to_config(self) -> "CharactersConfig":
    """Export the vocabulary and its special tokens as a ``CharactersConfig``.

    ``is_unique`` and ``is_sorted`` are both set to False in the exported
    config.
    """
    special_tokens = {
        "pad": self.pad,
        "eos": self.eos,
        "bos": self.bos,
        "blank": self.blank,
    }
    return CharactersConfig(vocab_dict=self._vocab, is_unique=False, is_sorted=False, **special_tokens)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_chars(self):
|
def num_chars(self):
|
||||||
"""Return number of tokens in the vocabulary."""
|
"""Return number of tokens in the vocabulary."""
|
||||||
|
@ -174,6 +199,14 @@ class BaseCharacters:
|
||||||
def blank_id(self) -> int:
|
def blank_id(self) -> int:
|
||||||
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
return self.char_to_id(self.blank) if self.blank else len(self.vocab)
|
||||||
|
|
||||||
|
@property
def eos_id(self) -> int:
    """ID of the eos character, or the vocabulary length when no eos character is set."""
    if not self.eos:
        return len(self.vocab)
    return self.char_to_id(self.eos)
|
||||||
|
|
||||||
|
@property
def bos_id(self) -> int:
    """ID of the bos character, or the vocabulary length when no bos character is set."""
    if not self.bos:
        return len(self.vocab)
    return self.char_to_id(self.bos)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def characters(self):
|
def characters(self):
|
||||||
return self._characters
|
return self._characters
|
||||||
|
|
|
@ -108,11 +108,12 @@ class TTSTokenizer:
|
||||||
text = self.text_cleaner(text)
|
text = self.text_cleaner(text)
|
||||||
if self.use_phonemes:
|
if self.use_phonemes:
|
||||||
text = self.phonemizer.phonemize(text, separator="", language=language)
|
text = self.phonemizer.phonemize(text, separator="", language=language)
|
||||||
|
text = self.encode(text)
|
||||||
if self.add_blank:
|
if self.add_blank:
|
||||||
text = self.intersperse_blank_char(text, True)
|
text = self.intersperse_blank_char(text, True)
|
||||||
if self.use_eos_bos:
|
if self.use_eos_bos:
|
||||||
text = self.pad_with_bos_eos(text)
|
text = self.pad_with_bos_eos(text)
|
||||||
return self.encode(text)
|
return text
|
||||||
|
|
||||||
def ids_to_text(self, id_sequence: List[int]) -> str:
|
def ids_to_text(self, id_sequence: List[int]) -> str:
|
||||||
"""Converts a sequence of token IDs to a string of text."""
|
"""Converts a sequence of token IDs to a string of text."""
|
||||||
|
@ -120,14 +121,14 @@ class TTSTokenizer:
|
||||||
|
|
||||||
def pad_with_bos_eos(self, char_sequence: List[str]):
|
def pad_with_bos_eos(self, char_sequence: List[str]):
|
||||||
"""Pads a sequence with the special BOS and EOS characters."""
|
"""Pads a sequence with the special BOS and EOS characters."""
|
||||||
return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
|
return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id]
|
||||||
|
|
||||||
def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
|
def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
|
||||||
"""Intersperses the blank character between characters in a sequence.
|
"""Intersperses the blank character between characters in a sequence.
|
||||||
|
|
||||||
Use the ```blank``` character if defined else use the ```pad``` character.
|
Use the ```blank``` character if defined else use the ```pad``` character.
|
||||||
"""
|
"""
|
||||||
char_to_use = self.characters.blank if use_blank_char else self.characters.pad
|
char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad
|
||||||
result = [char_to_use] * (len(char_sequence) * 2 + 1)
|
result = [char_to_use] * (len(char_sequence) * 2 + 1)
|
||||||
result[1::2] = char_sequence
|
result[1::2] = char_sequence
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import tarfile
|
||||||
import zipfile
|
import zipfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import copyfile, rmtree
|
from shutil import copyfile, rmtree
|
||||||
|
@ -245,6 +246,30 @@ class ModelManager(object):
|
||||||
else:
|
else:
|
||||||
print(" > Model's license - No license information available")
|
print(" > Model's license - No license information available")
|
||||||
|
|
||||||
|
def download_fairseq_model(self, model_name, output_path):
    """Download and unpack the fairseq archive for the language in ``model_name``.

    Args:
        model_name (str): Model name with four ``/``-separated parts; only the
            second part (the language code) is used to pick the archive.
        output_path (str): Folder passed on to ``_download_tar_file`` as the
            download and extraction target.

    Raises:
        ValueError: If ``model_name`` does not split into exactly four parts.
    """
    URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
    _, lang, _, _ = model_name.split("/")
    # A URL is not a filesystem path: build it with plain string formatting
    # rather than os.path.join, whose separator handling is OS-specific.
    model_download_uri = f"{URI_PREFIX}{lang}.tar.gz"
    self._download_tar_file(model_download_uri, output_path, self.progress_bar)
|
||||||
|
|
||||||
|
def _set_model_item(self, model_name):
|
||||||
|
# fetch model info from the dict
|
||||||
|
model_type, lang, dataset, model = model_name.split("/")
|
||||||
|
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
||||||
|
if "fairseq" in model_name:
|
||||||
|
model_item = {
|
||||||
|
"model_type": "tts_models",
|
||||||
|
"license": "CC BY-NC 4.0",
|
||||||
|
"default_vocoder": None,
|
||||||
|
"author": "fairseq",
|
||||||
|
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# get model from models.json
|
||||||
|
model_item = self.models_dict[model_type][lang][dataset][model]
|
||||||
|
model_item["model_type"] = model_type
|
||||||
|
return model_item, model_full_name, model
|
||||||
|
|
||||||
def download_model(self, model_name):
|
def download_model(self, model_name):
|
||||||
"""Download model files given the full model name.
|
"""Download model files given the full model name.
|
||||||
Model name is in the format
|
Model name is in the format
|
||||||
|
@ -259,11 +284,7 @@ class ModelManager(object):
|
||||||
Args:
|
Args:
|
||||||
model_name (str): model name as explained above.
|
model_name (str): model name as explained above.
|
||||||
"""
|
"""
|
||||||
# fetch model info from the dict
|
model_item, model_full_name, model = self._set_model_item(model_name)
|
||||||
model_type, lang, dataset, model = model_name.split("/")
|
|
||||||
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
|
|
||||||
model_item = self.models_dict[model_type][lang][dataset][model]
|
|
||||||
model_item["model_type"] = model_type
|
|
||||||
# set the model specific output path
|
# set the model specific output path
|
||||||
output_path = os.path.join(self.output_prefix, model_full_name)
|
output_path = os.path.join(self.output_prefix, model_full_name)
|
||||||
if os.path.exists(output_path):
|
if os.path.exists(output_path):
|
||||||
|
@ -271,16 +292,20 @@ class ModelManager(object):
|
||||||
else:
|
else:
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
print(f" > Downloading model to {output_path}")
|
print(f" > Downloading model to {output_path}")
|
||||||
# download from github release
|
# download from fairseq
|
||||||
if isinstance(model_item["github_rls_url"], list):
|
if "fairseq" in model_name:
|
||||||
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
self.download_fairseq_model(model_name, output_path)
|
||||||
else:
|
else:
|
||||||
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
# download from github release
|
||||||
self.print_model_license(model_item=model_item)
|
if isinstance(model_item["github_rls_url"], list):
|
||||||
|
self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
|
else:
|
||||||
|
self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
|
||||||
|
self.print_model_license(model_item=model_item)
|
||||||
# find downloaded files
|
# find downloaded files
|
||||||
output_model_path = output_path
|
output_model_path = output_path
|
||||||
output_config_path = None
|
output_config_path = None
|
||||||
if model != "tortoise-v2":
|
if model != "tortoise-v2" and "fairseq" not in model_name:
|
||||||
output_model_path, output_config_path = self._find_files(output_path)
|
output_model_path, output_config_path = self._find_files(output_path)
|
||||||
# update paths in the config.json
|
# update paths in the config.json
|
||||||
self._update_paths(output_path, output_config_path)
|
self._update_paths(output_path, output_config_path)
|
||||||
|
@ -421,6 +446,39 @@ class ModelManager(object):
|
||||||
# remove the extracted folder
|
# remove the extracted folder
|
||||||
rmtree(os.path.join(output_folder, z.namelist()[0]))
|
rmtree(os.path.join(output_folder, z.namelist()[0]))
|
||||||
|
|
||||||
|
@staticmethod
def _download_tar_file(file_url, output_folder, progress_bar):
    """Download a tar archive, extract it and flatten its top-level folder.

    The archive is streamed to a temporary file inside ``output_folder``,
    extracted there, and then deleted. The entries of the archive's first
    member (taken to be its top-level folder) are copied directly into
    ``output_folder`` and the extracted folder is removed.

    Args:
        file_url (str): URL of the tar archive.
        output_folder (str): Existing folder to download and extract into.
        progress_bar (bool): Whether to show a tqdm download progress bar.

    Raises:
        tarfile.ReadError: If the downloaded file is not a readable tar archive.
    """
    temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
    # the bar gets its own name so the boolean flag is not rebound to a tqdm object
    download_bar = None
    try:
        # download the file; the response is closed as soon as streaming ends
        with requests.get(file_url, stream=True) as r:
            total_size_in_bytes = int(r.headers.get("content-length", 0))
            block_size = 1024  # 1 Kibibyte
            if progress_bar:
                download_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
            with open(temp_tar_name, "wb") as file:
                for data in r.iter_content(block_size):
                    if download_bar is not None:
                        download_bar.update(len(data))
                    file.write(data)
        # extract the file
        # NOTE(security): extractall() trusts the member paths stored in the
        # downloaded archive; consider validating members or using an
        # extraction filter where the running Python version supports it.
        with tarfile.open(temp_tar_name) as t:
            t.extractall(output_folder)
            tar_names = t.getnames()
    except tarfile.ReadError:
        print(f" > Error: Bad tar file - {file_url}")
        # bare raise keeps the original message and traceback
        raise
    finally:
        if download_bar is not None:
            download_bar.close()
        # delete the archive after extraction, and also when something failed
        if os.path.exists(temp_tar_name):
            os.remove(temp_tar_name)
    # move the files to the outer path
    extracted_root = os.path.join(output_folder, tar_names[0])
    for file_path in os.listdir(extracted_root):
        src_path = os.path.join(extracted_root, file_path)
        dst_path = os.path.join(output_folder, os.path.basename(file_path))
        if src_path != dst_path:
            copyfile(src_path, dst_path)
    # remove the extracted folder
    rmtree(extracted_root)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _download_model_files(file_urls, output_folder, progress_bar):
|
def _download_model_files(file_urls, output_folder, progress_bar):
|
||||||
"""Download the github releases"""
|
"""Download the github releases"""
|
||||||
|
|
|
@ -7,7 +7,9 @@ import pysbd
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.config import load_config
|
from TTS.config import load_config
|
||||||
|
from TTS.tts.configs.vits_config import VitsConfig
|
||||||
from TTS.tts.models import setup_model as setup_tts_model
|
from TTS.tts.models import setup_model as setup_tts_model
|
||||||
|
from TTS.tts.models.vits import Vits
|
||||||
|
|
||||||
# pylint: disable=unused-wildcard-import
|
# pylint: disable=unused-wildcard-import
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
|
@ -98,8 +100,12 @@ class Synthesizer(object):
|
||||||
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
|
||||||
|
|
||||||
if model_dir:
|
if model_dir:
|
||||||
self._load_tts_from_dir(model_dir, use_cuda)
|
if "fairseq" in model_dir:
|
||||||
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
self._load_fairseq_from_dir(model_dir, use_cuda)
|
||||||
|
self.output_sample_rate = self.tts_config.audio["sample_rate"]
|
||||||
|
else:
|
||||||
|
self._load_tts_from_dir(model_dir, use_cuda)
|
||||||
|
self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_segmenter(lang: str):
|
def _get_segmenter(lang: str):
|
||||||
|
@ -133,12 +139,23 @@ class Synthesizer(object):
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
self.vc_model.cuda()
|
self.vc_model.cuda()
|
||||||
|
|
||||||
|
def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
    """Build a VITS model and load a fairseq checkpoint from ``model_dir`` into it.

    The model is created from a default ``VitsConfig`` and then loads itself
    from the directory via ``load_fairseq_checkpoint``. Afterwards
    ``self.tts_config`` is replaced by the config held by the loaded model.

    Args:
        model_dir (str): Directory holding the fairseq checkpoint files.
        use_cuda (bool): Move the loaded model to the GPU when True.
    """
    default_config = VitsConfig()
    self.tts_config = default_config

    fairseq_model = Vits.init_from_config(default_config)
    self.tts_model = fairseq_model
    fairseq_model.load_fairseq_checkpoint(default_config, checkpoint_dir=model_dir, eval=True)

    # the loader may have changed the model's config, so re-read it from the model
    self.tts_config = fairseq_model.config
    if use_cuda:
        fairseq_model.cuda()
|
||||||
|
|
||||||
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
|
||||||
"""Load the TTS model from a directory.
|
"""Load the TTS model from a directory.
|
||||||
|
|
||||||
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
|
We assume the model knows how to load itself from the directory and there is a config.json file in the directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
config = load_config(os.path.join(model_dir, "config.json"))
|
config = load_config(os.path.join(model_dir, "config.json"))
|
||||||
self.tts_config = config
|
self.tts_config = config
|
||||||
self.tts_model = setup_tts_model(config)
|
self.tts_model = setup_tts_model(config)
|
||||||
|
|
|
@ -128,7 +128,7 @@ wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0],
|
||||||
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
|
tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
|
||||||
```
|
```
|
||||||
|
|
||||||
Here is an example for a single speaker model.
|
#### Here is an example for a single speaker model.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Init TTS with the target model name
|
# Init TTS with the target model name
|
||||||
|
@ -137,7 +137,7 @@ tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False,
|
||||||
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
|
tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path=OUTPUT_PATH)
|
||||||
```
|
```
|
||||||
|
|
||||||
Example voice cloning with YourTTS in English, French and Portuguese:
|
#### Example voice cloning with YourTTS in English, French and Portuguese:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
|
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
|
||||||
|
@ -146,15 +146,16 @@ tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wa
|
||||||
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
|
tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
|
||||||
```
|
```
|
||||||
|
|
||||||
Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
|
#### Example voice conversion converting speaker of the `source_wav` to the speaker of the `target_wav`
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
|
tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
|
||||||
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
|
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
|
||||||
```
|
```
|
||||||
|
|
||||||
Example voice cloning by a single speaker TTS model combining with the voice conversion model. This way, you can
|
#### Example voice cloning with a single-speaker TTS model combined with the voice conversion model.
|
||||||
clone voices by using any model in 🐸TTS.
|
|
||||||
|
This way, you can clone voices by using any model in 🐸TTS.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
|
tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
|
||||||
|
@ -163,8 +164,11 @@ tts.tts_with_vc_to_file(
|
||||||
speaker_wav="target/speaker.wav",
|
speaker_wav="target/speaker.wav",
|
||||||
file_path="output.wav"
|
file_path="output.wav"
|
||||||
)
|
)
|
||||||
|
```
|
||||||
|
|
||||||
Example text to speech using [🐸Coqui Studio](https://coqui.ai) models. You can use all of your available speakers in the studio.
|
#### Example text to speech using [🐸Coqui Studio](https://coqui.ai) models.
|
||||||
|
|
||||||
|
You can use all of your available speakers in the studio.
|
||||||
[🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
|
[🐸Coqui Studio](https://coqui.ai) API token is required. You can get it from the [account page](https://coqui.ai/account).
|
||||||
You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
|
You should set the `COQUI_STUDIO_TOKEN` environment variable to use the API token.
|
||||||
|
|
||||||
|
@ -194,3 +198,22 @@ api.list_speakers()
|
||||||
api.list_voices()
|
api.list_voices()
|
||||||
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
|
wav, sample_rate = api.tts(text="This is a test.", speaker=api.speakers[0].name, emotion="Happy", speed=1.5)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Example text to speech using **Fairseq models in ~1100 languages** 🤯.
|
||||||
|
For these models use the following name format: `tts_models/<lang-iso_code>/fairseq/vits`.
|
||||||
|
|
||||||
|
You can find the list of language ISO codes [here](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) and learn about the Fairseq models [here](https://github.com/facebookresearch/fairseq/tree/main/examples/mms).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from TTS.api import TTS
|
||||||
|
api = TTS(model_name="tts_models/eng/fairseq/vits", gpu=True)
|
||||||
|
api.tts_to_file("This is a test.", file_path="output.wav")
|
||||||
|
|
||||||
|
# TTS with on-the-fly voice conversion
|
||||||
|
api = TTS("tts_models/deu/fairseq/vits")
|
||||||
|
api.tts_with_vc_to_file(
|
||||||
|
"Wie sage ich auf Italienisch, dass ich dich liebe?",
|
||||||
|
speaker_wav="target/speaker.wav",
|
||||||
|
file_path="output.wav"
|
||||||
|
)
|
||||||
|
```
|
|
@ -60,7 +60,7 @@ if is_coqui_available:
|
||||||
self.assertIsNone(tts.languages)
|
self.assertIsNone(tts.languages)
|
||||||
|
|
||||||
def test_studio_model(self):
|
def test_studio_model(self):
|
||||||
tts = TTS(model_name="coqui_studio/en/Torcull Diarmuid/coqui_studio")
|
tts = TTS(model_name="coqui_studio/en/Zacharie Aimilios/coqui_studio")
|
||||||
tts.tts_to_file(text="This is a test.")
|
tts.tts_to_file(text="This is a test.")
|
||||||
|
|
||||||
# check speed > 2.0 raises error
|
# check speed > 2.0 raises error
|
||||||
|
@ -83,6 +83,10 @@ if is_coqui_available:
|
||||||
wav = tts.tts(text="This is a test.", speed=2.0, emotion="Sad")
|
wav = tts.tts(text="This is a test.", speed=2.0, emotion="Sad")
|
||||||
self.assertGreater(len(wav), 0)
|
self.assertGreater(len(wav), 0)
|
||||||
|
|
||||||
|
def test_fairseq_model(self):  # pylint: disable=no-self-use
    """Smoke test: build the English Fairseq VITS model by name and synthesize one sentence to a file."""
    fairseq_tts = TTS(model_name="tts_models/eng/fairseq/vits")
    fairseq_tts.tts_to_file(text="This is a test.")
|
||||||
|
|
||||||
def test_multi_speaker_multi_lingual_model(self):
|
def test_multi_speaker_multi_lingual_model(self):
|
||||||
tts = TTS()
|
tts = TTS()
|
||||||
tts.load_tts_model_by_name(tts.models[0]) # YourTTS
|
tts.load_tts_model_by_name(tts.models[0]) # YourTTS
|
||||||
|
|
Loading…
Reference in New Issue