🔥 XTTS implementation

This commit is contained in:
Eren Gölge 2023-09-08 12:40:31 +02:00 committed by Eren Gölge
parent 4d3f23b5d3
commit 4033db5f4b
28 changed files with 5866 additions and 21 deletions

View File

@@ -111,6 +111,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
- Delightful TTS: [paper](https://arxiv.org/abs/2110.12612)
### End-to-End Models
- ⓍTTS: [blog]()
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
- 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418)
- 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts)
@@ -248,11 +249,11 @@ tts.tts_with_vc_to_file(
```
#### Example using [🐸Coqui Studio](https://coqui.ai) voices.
You can access all of your cloned voices and built-in speakers in [🐸Coqui Studio](https://coqui.ai).
To do this, you'll need an API token, which you can obtain from the [account page](https://coqui.ai/account).
After obtaining the API token, you'll need to configure the COQUI_STUDIO_TOKEN environment variable.
Once you have a valid API token in place, the studio speakers will be displayed as distinct models within the list.
These models will follow the naming convention `coqui_studio/en/<studio_speaker_name>/coqui_studio`
```python
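# A minimal sketch, not the verbatim README snippet; replace the placeholders with your
# own API token and a real studio speaker name.
import os
from TTS.api import TTS

os.environ["COQUI_STUDIO_TOKEN"] = "<your_api_token>"
# With the token set, studio speakers show up in the model list.
print([name for name in TTS.list_models() if "coqui_studio" in name])
tts = TTS(model_name="coqui_studio/en/<studio_speaker_name>/coqui_studio")
tts.tts_to_file(text="This is a test.", file_path="studio_output.wav")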

View File

@@ -2,6 +2,18 @@
"tts_models": {
"multilingual": {
"multi-dataset": {
"xtts_v1": {
"description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
"hf_url": [
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
],
"default_vocoder": null,
"commit": "e9a1953e",
"license": "Coqui Community Model License",
"contact": "info@coqui.ai"
},
"your_tts": {
"description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",
@@ -881,4 +893,4 @@
}
}
}
}
}

View File

@@ -105,6 +105,9 @@ class TTS(nn.Module):
@property
def is_multi_lingual(self):
# TODO: fix this
if "xtts" in self.model_name:
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
return self.synthesizer.tts_model.language_manager.num_languages > 1
return False
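# Illustrative usage sketch (not part of this diff): because of the check above, any
# "xtts" model is treated as multi-lingual, so callers pass a `language` argument along
# with a reference clip for voice cloning. Model name and file paths are placeholders.
from TTS.api import TTS

xtts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
assert xtts.is_multi_lingual
xtts.tts_to_file(
    text="Hello world!",
    speaker_wav="reference.wav",  # placeholder: a short clip of the target voice
    language="en",
    file_path="xtts_output.wav",
)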

View File

@@ -392,7 +392,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
if args.encoder_path is not None:
encoder_path = args.encoder_path
encoder_config_path = args.encoder_config_path
device = args.device
if args.use_cuda:
device = "cuda"
@@ -459,7 +459,9 @@ If you don't specify any models, then it uses LJSpeech based English model.
target_wav=args.target_wav,
)
elif model_dir is not None:
wav = synthesizer.tts(args.text, speaker_name=args.speaker_idx)
wav = synthesizer.tts(
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
)
# save the results
print(" > Saving output to {}".format(args.out_path))

View File

@@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
@dataclass
class XttsConfig(BaseTTSConfig):
"""Defines parameters for XTTS TTS model.
Args:
model (str):
Model name. Do not change unless you know what you are doing.
model_args (XttsArgs):
Model architecture arguments. Defaults to `XttsArgs()`.
audio (XttsAudioConfig):
Audio processing configuration. Defaults to `XttsAudioConfig()`.
model_dir (str):
Path to the folder that has all the XTTS models. Defaults to None.
temperature (float):
Temperature for the autoregressive model inference. Larger values make predictions more creative at the expense of stability. Defaults to `0.2`.
length_penalty (float):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length,
which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative),
length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
repetition_penalty (float):
The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`.
top_p (float):
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
Defaults to `0.8`.
cond_free_k (float):
Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
Formula is: output = cond_present_output * (cond_free_k + 1) - cond_absent_output * cond_free_k. Defaults to `2.0`.
diffusion_temperature (float):
Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to `1.0`.
num_gpt_outputs (int):
Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
As XTTS is a probabilistic model, more samples means a higher probability of creating something "great".
Defaults to `16`.
decoder_iterations (int):
Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
however. Defaults to `30`.
decoder_sampler (str):
Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
Example:
>>> from TTS.tts.configs.xtts_config import XttsConfig
>>> config = XttsConfig()
"""
model: str = "xtts"
# model specific params
model_args: XttsArgs = field(default_factory=XttsArgs)
audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
model_dir: str = None
languages: List[str] = field(
default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"]
)
# inference params
temperature: float = 0.2
length_penalty: float = 1.0
repetition_penalty: float = 2.0
top_k: int = 50
top_p: float = 0.8
cond_free_k: float = 2.0
diffusion_temperature: float = 1.0
num_gpt_outputs: int = 16
decoder_iterations: int = 30
decoder_sampler: str = "ddim"
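# Illustrative usage sketch (not part of this file): override a few of the inference
# parameters documented above; every other field keeps the defaults declared in the class.
example_config = XttsConfig(
    temperature=0.6,            # less aggressive sampling than the 0.2 default
    decoder_iterations=60,      # more diffusion refinement steps
    decoder_sampler="dpm++2m",  # the alternative sampler named in the docstring
)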

View File

@@ -5,15 +5,14 @@ from tokenizers import Tokenizer
from TTS.tts.utils.text.cleaners import english_cleaners
DEFAULT_VOCAB_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
)
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
def __init__(self, vocab_file=None, vocab_str=None):
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
if vocab_str is not None:
self.tokenizer = Tokenizer.from_str(vocab_str)
def preprocess_text(self, txt):
txt = english_cleaners(txt)
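# Usage sketch (not part of the diff): after this change the tokenizer can be built
# either from a vocab file on disk or directly from a JSON string, e.g. the vocab.json
# shipped with the XTTS model (the paths below are placeholders).
tok_from_file = VoiceBpeTokenizer(vocab_file="/path/to/XTTS-v1/vocab.json")
with open("/path/to/XTTS-v1/vocab.json", "r", encoding="utf-8") as f:
    tok_from_str = VoiceBpeTokenizer(vocab_str=f.read())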

File diff suppressed because it is too large

TTS/tts/layers/xtts/dvae.py (new file, 393 lines)
View File

@@ -0,0 +1,393 @@
import functools
from math import sqrt
import torch
import torch.distributed as distributed
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import rearrange
def default(val, d):
return val if val is not None else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
def dvae_wav_to_mel(
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
):
mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=1024,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
norm="slaney",
).to(device)
wav = wav.to(device)
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device)
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel
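# Shape sketch (relies on the imports at the top of this file; not part of the commit):
# a 1-second 22.05 kHz clip yields an 80-band log-mel of ~87 frames. mel_norms is passed
# explicitly so the example does not depend on the clips_mel_norms.pth file above.
fake_wav = torch.randn(1, 22050)
fake_norms = torch.ones(80)  # placeholder per-band normalisation values
print(dvae_wav_to_mel(fake_wav, mel_norms=fake_norms).shape)  # torch.Size([1, 80, 87])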
class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
super().__init__()
self.dim = dim
self.n_embed = n_embed
self.decay = decay
self.eps = eps
self.balancing_heuristic = balancing_heuristic
self.codes = None
self.max_codes = 64000
self.codes_full = False
self.new_return_order = new_return_order
embed = torch.randn(dim, n_embed)
self.register_buffer("embed", embed)
self.register_buffer("cluster_size", torch.zeros(n_embed))
self.register_buffer("embed_avg", embed.clone())
def forward(self, input, return_soft_codes=False):
if self.balancing_heuristic and self.codes_full:
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1)
ep = self.embed.permute(1, 0)
ea = self.embed_avg.permute(1, 0)
rand_embed = torch.randn_like(ep) * mask
self.embed = (ep * ~mask + rand_embed).permute(1, 0)
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
self.cluster_size = self.cluster_size * ~mask.squeeze()
if torch.any(mask):
print(f"Reset {torch.sum(mask)} embedding codes.")
self.codes = None
self.codes_full = False
flatten = input.reshape(-1, self.dim)
dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True)
soft_codes = -dist
_, embed_ind = soft_codes.max(1)
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
embed_ind = embed_ind.view(*input.shape[:-1])
quantize = self.embed_code(embed_ind)
if self.balancing_heuristic:
if self.codes is None:
self.codes = embed_ind.flatten()
else:
self.codes = torch.cat([self.codes, embed_ind.flatten()])
if len(self.codes) > self.max_codes:
self.codes = self.codes[-self.max_codes :]
self.codes_full = True
if self.training:
embed_onehot_sum = embed_onehot.sum(0)
embed_sum = flatten.transpose(0, 1) @ embed_onehot
if distributed.is_initialized() and distributed.get_world_size() > 1:
distributed.all_reduce(embed_onehot_sum)
distributed.all_reduce(embed_sum)
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
n = self.cluster_size.sum()
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
self.embed.data.copy_(embed_normalized)
diff = (quantize.detach() - input).pow(2).mean()
quantize = input + (quantize - input).detach()
if return_soft_codes:
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
elif self.new_return_order:
return quantize, embed_ind, diff
else:
return quantize, diff, embed_ind
def embed_code(self, embed_id):
return F.embedding(embed_id, self.embed.transpose(0, 1))
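# Illustrative sketch (not part of the commit): quantize a batch of 512-dim latent
# frames against a 1024-entry codebook; with the default return order the call yields
# (quantized, commitment_diff, code_indices).
vq = Quantize(dim=512, n_embed=1024)
latents = torch.randn(8, 35, 512)  # (batch, time, dim)
quantized, commit_diff, code_indices = vq(latents)
print(quantized.shape, code_indices.shape, float(commit_diff))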
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
# In other words, it attempts to force the discretization function toward equal mean utilization across all discrete
# values with the specified expected variance.
class DiscretizationLoss(nn.Module):
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
super().__init__()
self.discrete_bins = discrete_bins
self.dim = dim
self.dist = torch.distributions.Normal(0, scale=expected_variance)
if store_past > 0:
self.record_past = True
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu"))
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu"))
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
else:
self.record_past = False
def forward(self, x):
other_dims = set(range(len(x.shape))) - set([self.dim])
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
averaged = averaged - averaged.mean()
if self.record_past:
acc_count = self.accumulator.shape[0]
avg = averaged.detach().clone()
if self.accumulator_filled > 0:
averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count
# Also push averaged into the accumulator.
self.accumulator[self.accumulator_index] = avg
self.accumulator_index += 1
if self.accumulator_index >= acc_count:
self.accumulator_index *= 0
if self.accumulator_filled <= 0:
self.accumulator_filled += 1
return torch.sum(-self.dist.log_prob(averaged))
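# Illustrative sketch (not part of the commit): penalise a soft code distribution whose
# per-code utilisation drifts away from uniform, as used by DiscreteVAE below.
disc_loss = DiscretizationLoss(discrete_bins=512, dim=2, expected_variance=1 / (512 * 2), store_past=16)
soft_codes = torch.softmax(torch.randn(4, 35, 512), dim=-1)  # (batch, time, bins)
print(float(disc_loss(soft_codes)))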
class ResBlock(nn.Module):
def __init__(self, chan, conv, activation):
super().__init__()
self.net = nn.Sequential(
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 3, padding=1),
activation(),
conv(chan, chan, 1),
)
def forward(self, x):
return self.net(x) + x
class UpsampledConv(nn.Module):
def __init__(self, conv, *args, **kwargs):
super().__init__()
assert "stride" in kwargs.keys()
self.stride = kwargs["stride"]
del kwargs["stride"]
self.conv = conv(*args, **kwargs)
def forward(self, x):
up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest")
return self.conv(up)
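# Illustrative sketch (not part of the commit): nearest-neighbour upsample by the
# "stride" factor, then apply a regular convolution; this is the decoder path taken
# when use_transposed_convs=False in DiscreteVAE below.
up = UpsampledConv(nn.Conv1d, 64, 32, kernel_size=3, stride=2, padding=1)
print(up(torch.randn(4, 64, 100)).shape)  # torch.Size([4, 32, 200])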
# DiscreteVAE partially derived from lucidrains DALLE implementation
# Credit: https://github.com/lucidrains/DALLE-pytorch
class DiscreteVAE(nn.Module):
def __init__(
self,
positional_dims=2,
num_tokens=512,
codebook_dim=512,
num_layers=3,
num_resnet_blocks=0,
hidden_dim=64,
channels=3,
stride=2,
kernel_size=4,
use_transposed_convs=True,
encoder_norm=False,
activation="relu",
smooth_l1_loss=False,
straight_through=False,
normalization=None, # ((0.5,) * 3, (0.5,) * 3),
record_codes=False,
discretization_loss_averaging_steps=100,
lr_quantizer_args={},
):
super().__init__()
has_resblocks = num_resnet_blocks > 0
self.num_tokens = num_tokens
self.num_layers = num_layers
self.straight_through = straight_through
self.positional_dims = positional_dims
self.discrete_loss = DiscretizationLoss(
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
)
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
if positional_dims == 2:
conv = nn.Conv2d
conv_transpose = nn.ConvTranspose2d
else:
conv = nn.Conv1d
conv_transpose = nn.ConvTranspose1d
if not use_transposed_convs:
conv_transpose = functools.partial(UpsampledConv, conv)
if activation == "relu":
act = nn.ReLU
elif activation == "silu":
act = nn.SiLU
else:
raise NotImplementedError()
enc_layers = []
dec_layers = []
if num_layers > 0:
enc_chans = [hidden_dim * 2**i for i in range(num_layers)]
dec_chans = list(reversed(enc_chans))
enc_chans = [channels, *enc_chans]
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
dec_chans = [dec_init_chan, *dec_chans]
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(
nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act())
)
dec_out_chans = dec_chans[-1]
innermost_dim = dec_chans[0]
else:
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
dec_out_chans = hidden_dim
innermost_dim = hidden_dim
for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
enc_layers.append(ResBlock(innermost_dim, conv, act))
if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1))
self.encoder = nn.Sequential(*enc_layers)
self.decoder = nn.Sequential(*dec_layers)
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
# take care of normalization within class
self.normalization = normalization
self.record_codes = record_codes
if record_codes:
self.codes = torch.zeros((1228800,), dtype=torch.long)
self.code_ind = 0
self.total_codes = 0
self.internal_step = 0
def norm(self, images):
if self.normalization is None:
return images
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
images = images.clone()
images.sub_(means).div_(stds)
return images
def get_debug_values(self, step, __):
if self.record_codes and self.total_codes > 0:
# Report annealing schedule
return {"histogram_codes": self.codes[: self.total_codes]}
else:
return {}
@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
img = self.norm(images)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, _ = self.codebook(logits)
self.log_codes(codes)
return codes
def decode(self, img_seq):
self.log_codes(img_seq)
if hasattr(self.codebook, "embed_code"):
image_embeds = self.codebook.embed_code(img_seq)
else:
image_embeds = F.embedding(img_seq, self.codebook.codebook)
b, n, d = image_embeds.shape
kwargs = {}
if self.positional_dims == 1:
arrange = "b n d -> b d n"
else:
h = w = int(sqrt(n))
arrange = "b (h w) d -> b d h w"
kwargs = {"h": h, "w": w}
image_embeds = rearrange(image_embeds, arrange, **kwargs)
images = [image_embeds]
for layer in self.decoder:
images.append(layer(images[-1]))
return images[-1], images[-2]
def infer(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
return self.decode(codes)
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
# more lossy (but useful for determining network performance).
def forward(self, img):
img = self.norm(img)
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
sampled, codes, commitment_loss = self.codebook(logits)
sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1))
if self.training:
out = sampled
for d in self.decoder:
out = d(out)
self.log_codes(codes)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out, _ = self.decode(codes)
# reconstruction loss
recon_loss = self.loss_fn(img, out, reduction="none")
return recon_loss, commitment_loss, out
def log_codes(self, codes):
# This is so we can debug the distribution of codes being learned.
if self.record_codes and self.internal_step % 10 == 0:
codes = codes.flatten()
l = codes.shape[0]
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
self.codes[i : i + l] = codes.cpu()
self.code_ind = self.code_ind + l
if self.code_ind >= self.codes.shape[0]:
self.code_ind = 0
self.total_codes += 1
self.internal_step += 1
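# Illustrative end-to-end sketch (not part of the commit; parameter values are small
# and arbitrary, not the XTTS DVAE settings): a 1-D DiscreteVAE over 80-band mel frames.
demo_dvae = DiscreteVAE(
    positional_dims=1, num_tokens=512, codebook_dim=512, num_layers=2,
    num_resnet_blocks=1, hidden_dim=64, channels=80, kernel_size=4, stride=2,
)
demo_mel = torch.randn(2, 80, 64)                      # (batch, mel_bands, frames)
recon_loss, commit_loss, recon = demo_dvae(demo_mel)   # training-mode forward
demo_codes = demo_dvae.get_codebook_indices(demo_mel)  # discrete token ids
print(recon.shape, demo_codes.shape)                   # (2, 80, 64) and (2, 16)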

TTS/tts/layers/xtts/gpt.py (new file, 545 lines)
View File

@@ -0,0 +1,545 @@
# ported from: https://github.com/neonbjb/tortoise-tts
import functools
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_dim, init=0.02, relative=False):
super().__init__()
# nn.Embedding
self.emb = torch.nn.Embedding(seq_len, model_dim)
# Initializing this way is standard for GPT-2
self.emb.weight.data.normal_(mean=0.0, std=init)
self.relative = relative
self.seq_len = seq_len
def forward(self, x):
sl = x.shape[1]
if self.relative:
start = random.randint(sl, self.seq_len) - sl
return self.emb(torch.arange(start, start + sl, device=x.device))
else:
return self.emb(torch.arange(0, sl, device=x.device))
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
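# Illustrative sketch (not part of the commit): one embedding per position of the input
# sequence; the result broadcasts over the batch when added to token embeddings.
pos = LearnedPositionEmbeddings(seq_len=402, model_dim=1024)
fake_tokens = torch.randint(0, 256, (2, 17))
print(pos(fake_tokens).shape)                   # torch.Size([17, 1024])
print(pos.get_fixed_embedding(5, "cpu").shape)  # torch.Size([1, 1, 1024])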
def build_hf_gpt_transformer(
layers,
model_dim,
heads,
max_mel_seq_len,
max_text_seq_len,
max_prompt_len,
checkpointing,
):
"""
GPT-2 implemented by the HuggingFace library.
"""
from transformers import GPT2Config, GPT2Model
gpt_config = GPT2Config(
vocab_size=256, # Unused.
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing,
)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
mel_pos_emb = (
LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_dim)
)
text_pos_emb = (
LearnedPositionEmbeddings(max_text_seq_len, model_dim)
if max_text_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_dim)
)
# gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
return gpt, mel_pos_emb, text_pos_emb, None, None
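# Illustrative sketch (not part of the commit; sizes are arbitrary): the returned GPT-2
# backbone has its built-in positional embeddings nulled, so positions come from the
# LearnedPositionEmbeddings returned alongside it.
demo_gpt, demo_mel_pos, demo_text_pos, _, _ = build_hf_gpt_transformer(
    layers=2, model_dim=64, heads=2, max_mel_seq_len=32, max_text_seq_len=16,
    max_prompt_len=8, checkpointing=False,
)
hidden = demo_gpt(inputs_embeds=torch.randn(1, 10, 64), return_dict=True).last_hidden_state
print(hidden.shape)  # torch.Size([1, 10, 64])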
class GPT(nn.Module):
def __init__(
self,
start_text_token=261,
stop_text_token=0,
layers=8,
model_dim=512,
heads=8,
max_text_tokens=120,
max_mel_tokens=250,
max_prompt_tokens=70,
max_conditioning_inputs=1,
code_stride_len=1024,
number_text_tokens=256,
num_audio_tokens=8194,
start_audio_token=8192,
stop_audio_token=8193,
train_solo_embeddings=False,
checkpointing=False,
average_conditioning_embeddings=False,
label_smoothing=0.0,
):
"""
Args:
"""
super().__init__()
self.label_smoothing = label_smoothing
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.num_audio_tokens = num_audio_tokens
self.start_audio_token = start_audio_token
self.stop_audio_token = stop_audio_token
self.start_prompt_token = start_audio_token
self.stop_prompt_token = stop_audio_token
self.layers = layers
self.heads = heads
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
self.max_prompt_tokens = max_prompt_tokens
self.code_stride_len = code_stride_len
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_dropout = nn.Dropout1d(0.1)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
(
self.gpt,
self.mel_pos_embedding,
self.text_pos_embedding,
self.mel_layer_pos_embedding,
self.text_layer_pos_embedding,
) = build_hf_gpt_transformer(
layers,
model_dim,
heads,
self.max_mel_tokens,
self.max_text_tokens,
self.max_prompt_tokens,
checkpointing,
)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
def init_gpt_for_inference(self, kv_cache=True):
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
gpt_config = GPT2Config(
vocab_size=self.max_mel_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.model_dim,
n_layer=self.layers,
n_head=self.heads,
gradient_checkpointing=False,
use_cache=True,
)
self.gpt_inference = GPT2InferenceModel(
gpt_config,
self.gpt,
self.mel_pos_embedding,
self.mel_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache,
)
self.gpt.wte = self.mel_embedding
def set_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_mel_padding(self, mel_input_tokens, code_lengths):
"""
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
that audio clip, reformats the tokens with stop_audio_token in place of the zero padding. This is required
preformatting to create a working TTS model.
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
for b in range(len(code_lengths)):
actual_end = code_lengths[b]
if actual_end < mel_input_tokens.shape[-1]:
mel_input_tokens[b, actual_end:] = self.stop_audio_token
return mel_input_tokens
def get_logits(
self,
first_inputs,
first_head,
second_inputs=None,
second_head=None,
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_text=None,
attn_mask_mel=None,
):
if prompt is not None:
offset = prompt.shape[1]
if second_inputs is not None:
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([prompt, first_inputs], dim=1)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attn_mask = None
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb,
return_dict=True,
output_attentions=get_attns,
attention_mask=attn_mask,
)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
if return_latent:
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
first_logits = enc[:, : first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1] :]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds
def get_prompts(self, prompt_codes):
"""
Create a prompt from the mel codes. This is used to condition the model on the mel codes.
Pad the prompt with start and stop mel tokens.
"""
prompt = prompt_codes
if self.training:
lengths = []
# Compute the real prompt length based on the first encounter with the token 83 used for padding
for i in range(prompt_codes.shape[0]):
length = 0
for j in range(prompt_codes.shape[1]):
if prompt_codes[i, j] == 83:
break
else:
length += 1
lengths.append(length)
# prompt_len = random.randint(1, 9) # in secs
prompt_len = 3
prompt_len = prompt_len * 24 # in frames
if prompt_codes.shape[-1] >= prompt_len:
new_prompt = []
for i in range(prompt_codes.shape[0]):
if lengths[i] < prompt_len:
start = 0
else:
start = random.randint(0, lengths[i] - prompt_len)
prompt = prompt_codes[:, start : start + prompt_len]
# add start and stop tokens
prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
return prompt
def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
"""
conds = None
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
if sample:
_len_secs = random.randint(2, 6) # in secs
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
if cond_input.shape[-1] >= cond_seg_len:
new_conds = []
for i in range(cond_input.shape[0]):
cond_len = int(cond_lens[i] / 1024)
if cond_len < cond_seg_len:
start = 0
else:
start = random.randint(0, cond_len - cond_seg_len)
cond_vec = cond_input[i, :, start : start + cond_seg_len]
new_conds.append(cond_vec)
conds = torch.stack(new_conds, dim=0)
else:
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
cond_frame_len = int((22050 / 1024) * cond_seg_len)
conds = cond_input[:, :, -cond_frame_len:]
conds = self.conditioning_encoder(conds)
else:
# already computed
conds = cond_input.unsqueeze(1)
return conds
def forward(
self,
text_inputs,
text_lengths,
audio_codes,
wav_lengths,
cond_lens=None,
cond_mels=None,
cond_latents=None,
loss_weights=None,
return_attentions=False,
return_latent=False,
):
"""
Forward pass that conditions the GPT on speaker conditioning (cond_mels or precomputed cond_latents),
text tokens, and audio codes.
cond_mels: MEL float tensor, (b, 1, 80, s)
text_inputs: long tensor, (b, t)
text_lengths: long tensor, (b,)
audio_codes: long tensor, (b, m)
wav_lengths: long tensor, (b,)
If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
"""
# ❗ FIXIT
if self.max_conditioning_inputs == 0:
assert cond_mels is None, " ❗ cond_mels is not None, but max_conditioning_inputs == 0"
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
max_mel_len = code_lengths.max()
if max_mel_len > audio_codes.shape[-1]:
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
silence = True
for idx, l in enumerate(code_lengths):
length = l.item()
while silence:
if audio_codes[idx, length - 1] != 83:
break
length -= 1
code_lengths[idx] = length
# 💖 Lovely assertions
assert (
max_mel_len <= audio_codes.shape[-1]
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
assert (
max_text_len <= text_inputs.shape[-1]
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
# Append stop token to text inputs
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
# Append silence token to mel codes
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
# Pad mel codes with stop_audio_token
audio_codes = self.set_mel_padding(audio_codes, code_lengths)
# Build input and target tensors
# Prepend start token to inputs and append stop token to targets
text_inputs, text_targets = self.set_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token
)
audio_codes, mel_targets = self.set_inputs_and_targets(
audio_codes, self.start_audio_token, self.stop_audio_token
)
# Set attn_mask
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_mel = torch.ones(
audio_codes.shape[0],
audio_codes.shape[1],
dtype=torch.bool,
device=audio_codes.device,
)
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
for idx, l in enumerate(code_lengths):
attn_mask_mel[idx, l + 1 :] = 0.0
# Compute text embeddings + positional embeddings
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
# Compute mel embeddings + positional embeddings
mel_emb = self.mel_embedding(audio_codes) + self.mel_pos_embedding(audio_codes)
# Compute speech conditioning input
if cond_latents is None:
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
# Get logits
sub = -5 # don't ask me why 😄
if self.training:
sub = -1
text_logits, mel_logits = self.get_logits(
text_emb,
self.text_head,
mel_emb,
self.mel_head,
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
if return_latent:
return mel_logits[:, :sub] # sub to prevent bla.
if return_attentions:
return mel_logits
# Set paddings to -1 to ignore them in loss
for idx, l in enumerate(text_lengths):
text_targets[idx, l + 1 :] = -1
for idx, l in enumerate(code_lengths):
mel_targets[idx, l + 1 :] = -1
# check if stoptoken is in every row of mel_targets
assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
0
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
# Compute losses
loss_text = F.cross_entropy(
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
)
loss_mel = F.cross_entropy(
mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
)
return loss_text.mean(), loss_mel.mean(), mel_logits
def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
self.compute_embeddings(cond_latents, text_inputs)
return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
def compute_embeddings(
self,
cond_latents,
text_inputs,
):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
emb = torch.cat([cond_latents, emb], dim=1)
self.gpt_inference.store_prefix_emb(emb)
gpt_inputs = torch.full(
(
emb.shape[0],
emb.shape[1] + 1, # +1 for the start_audio_token
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
gpt_inputs[:, -1] = self.start_audio_token
return gpt_inputs
def generate(
self,
cond_latents,
text_inputs,
**hf_generate_kwargs,
):
gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
gen = self.gpt_inference.generate(
gpt_inputs,
bos_token_id=self.start_audio_token,
pad_token_id=self.stop_audio_token,
eos_token_id=self.stop_audio_token,
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, gpt_inputs.shape[1] :], gen
return gen[:, gpt_inputs.shape[1] :]
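# Illustrative end-to-end inference sketch (not part of the commit; sizes are small and
# arbitrary, not the XTTS defaults). number_text_tokens is raised to 262 so the default
# start_text_token id (261) stays inside the text embedding table.
demo_gpt = GPT(
    layers=2, model_dim=64, heads=2,
    max_text_tokens=32, max_mel_tokens=64, max_prompt_tokens=8,
    number_text_tokens=262,
)
demo_gpt.init_gpt_for_inference(kv_cache=True)
demo_gpt.eval()
demo_cond = torch.randn(1, 12, 64)          # stand-in for get_style_emb(...).transpose(1, 2)
demo_text = torch.randint(1, 261, (1, 10))  # fake tokenized text
with torch.no_grad():
    demo_codes = demo_gpt.generate(
        demo_cond, demo_text,
        do_sample=True, top_k=50,
        max_new_tokens=32,  # cap the untrained model; takes precedence over the internal max_length
    )
print(demo_codes.shape)  # (1, <=32) sampled audio-code ids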

View File

@@ -0,0 +1,658 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
class GPT2InferenceModel(GPT2PreTrainedModel):
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
super().__init__(config)
self.transformer = gpt
self.pos_embedding = pos_emb
self.embeddings = embeddings
self.final_norm = norm
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
def store_prefix_emb(self, prefix_emb):
self.cached_prefix_emb = prefix_emb
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_prefix_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
# Create embedding
prefix_len = self.cached_prefix_emb.shape[1]
if input_ids.shape[1] != 1:
gen_inputs = input_ids[:, prefix_len:]
gen_emb = self.embeddings(gen_inputs)
gen_emb = gen_emb + self.pos_embedding(gen_emb)
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
prefix_emb = self.cached_prefix_emb.repeat_interleave(
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
)
else:
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
emb = torch.cat([prefix_emb, gen_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
class LearnedPositionEmbeddings(nn.Module):
def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
super().__init__()
self.emb = nn.Embedding(seq_len, model_channels)
nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
self.relative = relative
def forward(self, x):
seq_len = x.shape[1]
if self.relative:
start = torch.randint(seq_len, (1,), device=x.device).item()
positions = torch.arange(start, start + seq_len, device=x.device)
else:
positions = torch.arange(seq_len, device=x.device)
return self.emb(positions)
def get_fixed_embedding(self, ind, dev):
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
"""
Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
Args:
layers (int): Number of layers in the GPT-2 model.
model_channels (int): Dimension of the GPT-2 model.
heads (int): Number of heads in the GPT-2 model.
max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
max_text_seq_len (int): Maximum sequence length for the text.
max_prompt_len (int): Maximum length of the prompt.
checkpointing (bool): Whether to use gradient checkpointing.
Returns:
gpt (GPT2Model): GPT-2 model.
mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
"""
gpt_config = GPT2Config(
vocab_size=123,
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
n_embd=model_channels,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing,
)
gpt = GPT2Model(gpt_config)
del gpt.wpe
del gpt.wte
gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
audio_pos_emb = (
LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
if max_mel_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_channels)
)
text_pos_emb = (
LearnedPositionEmbeddings(max_text_seq_len, model_channels)
if max_text_seq_len != -1
else functools.partial(null_position_embeddings, dim=model_channels)
)
return gpt, audio_pos_emb, text_pos_emb
class XTTSGPTEncoder(nn.Module):
"""XTTS GPT Encoder model implementation.
Args:
start_text_token (int): Index of the start token in the text vocabulary.
stop_text_token (int): Index of the stop token in the text vocabulary.
n_layers (int): Number of layers in the GPT-2 model.
n_model_channels (int): Dimension of the GPT-2 model.
n_heads (int): Number of heads in the GPT-2 model.
max_text_tokens (int): Maximum number of text tokens.
max_audio_tokens (int): Maximum number of audio tokens.
max_prompt_tokens (int): Maximum number of prompt tokens.
audio_len_compression (int): Compression factor for the audio length.
number_text_tokens (int): Number of text tokens.
number_audio_codes (int): Number of audio codes.
start_mel_token (int): Index of the start token in the mel code vocabulary.
stop_mel_token (int): Index of the stop token in the mel code vocabulary.
checkpointing (bool): Whether or not to use gradient checkpointing at training.
"""
_inference_flag = False
def __init__(
self,
start_text_token=261,
stop_text_token=0,
n_layers=8,
n_model_channels=512,
n_heads=8,
max_text_tokens=120,
max_audio_tokens=250,
max_prompt_tokens=70,
audio_len_compression=1024,
number_text_tokens=256,
number_audio_codes=8194,
start_mel_token=8192,
stop_mel_token=8193,
checkpointing=True,
label_smoothing=0.0,
):
super().__init__()
self.label_smoothing = label_smoothing
self.number_text_tokens = number_text_tokens
self.start_text_token = start_text_token
self.stop_text_token = stop_text_token
self.number_audio_codes = number_audio_codes
self.start_mel_token = start_mel_token
self.stop_mel_token = stop_mel_token
self.start_prompt_token = start_mel_token
self.stop_prompt_token = stop_mel_token
self.n_layers = n_layers
self.n_heads = n_heads
self.n_model_channels = n_model_channels
self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
self.max_prompt_tokens = max_prompt_tokens
self.audio_len_compression = audio_len_compression
# embedding layers
self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
# initialize the GPT-2 model
(
self.gpt,
self.audio_pos_embedding,
self.text_pos_embedding,
) = init_gpt(
n_layers,
n_model_channels,
n_heads,
self.max_audio_tokens,
self.max_text_tokens,
self.max_prompt_tokens,
checkpointing,
)
# output layers
self.final_norm = nn.LayerNorm(n_model_channels)
self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
self._inference_flag = True
seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
gpt_config = GPT2Config(
vocab_size=self.max_audio_tokens,
n_positions=seq_length,
n_ctx=seq_length,
n_embd=self.n_model_channels,
n_layer=self.n_layers,
n_head=self.n_heads,
gradient_checkpointing=False,
use_cache=True,
)
self.inference_model = GPT2InferenceModel(
gpt_config,
self.gpt,
self.audio_pos_embedding,
self.audio_embedding,
self.final_norm,
self.mel_head,
kv_cache=kv_cache,
)
self.gpt.wte = self.audio_embedding
def set_inputs_and_targets(self, input, start_token, stop_token):
inp = F.pad(input, (1, 0), value=start_token)
tar = F.pad(input, (0, 1), value=stop_token)
return inp, tar
def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
for b in range(len(audio_token_lens)):
actual_end = audio_token_lens[b]
if actual_end < audio_tokens.shape[-1]:
audio_tokens[b, actual_end:] = self.stop_mel_token
return audio_tokens
def get_logits(
self,
speech_conditioning_inputs,
first_inputs,
first_head,
second_inputs=None,
second_head=None,
prompt=None,
get_attns=False,
return_latent=False,
attn_mask_text=None,
attn_mask_mel=None,
):
if prompt is not None and speech_conditioning_inputs is not None:
offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
if second_inputs is not None:
emb = torch.cat(
[speech_conditioning_inputs, prompt, first_inputs, second_inputs],
dim=1,
)
else:
emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
elif speech_conditioning_inputs is not None:
offset = speech_conditioning_inputs.shape[1]
if second_inputs is not None:
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
elif prompt is not None:
offset = prompt.shape[1]
if second_inputs is not None:
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([prompt, first_inputs], dim=1)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attn_mask = None
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb,
return_dict=True,
output_attentions=get_attns,
attention_mask=attn_mask,
)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[:, offset:]
enc = self.final_norm(enc)
if return_latent:
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
first_logits = enc[:, : first_inputs.shape[1]]
first_logits = first_head(first_logits)
first_logits = first_logits.permute(0, 2, 1)
if second_inputs is not None:
second_logits = enc[:, -second_inputs.shape[1] :]
second_logits = second_head(second_logits)
second_logits = second_logits.permute(0, 2, 1)
return first_logits, second_logits
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds
def get_prompts(self, prompt_codes):
prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token)
return prompt
def forward(
self,
text_inputs,
text_lengths,
audio_codes,
wav_lengths,
prompt_codes,
return_attentions=False,
return_latent=False,
):
max_text_len = text_lengths.max()
# Due to the convolution in the DVAE, codes do not end with silence at the right place; instead it predicts some
# intermediate values like [..., 186, 45, 45, 83] where it should actually end with 186.
# We keep the last 3 codes to prevent an abrupt ending of the audio.
# TODO: This might need some testing.
mel_lengths = torch.ceil(wav_lengths / self.audio_len_compression).long() + 3
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
max_mel_len = mel_lengths.max()
if max_mel_len > audio_codes.shape[-1]:
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
# silence aware lengths, skip the silence tokens at the end of the mel codes.
silence = True
for idx, l in enumerate(mel_lengths):
length = l.item()
while silence:
if audio_codes[idx, length - 1] != 83:
break
length -= 1
mel_lengths[idx] = length
# Lovely assertions
assert (
max_mel_len <= audio_codes.shape[-1]
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
assert (
max_text_len <= text_inputs.shape[-1]
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
# Append stop token to text inputs
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
# Append silence token to mel codes
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
# Pad mel codes with STOP_MEL_TOKEN
audio_codes = self.set_audio_tokens_padding(audio_codes, mel_lengths)
# Compute speech conditioning input
conds = None
if speech_conditioning_input is not None:
if not return_latent:
# Compute speech conditioning input
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
if self.average_conditioning_embeddings:
conds = conds.mean(dim=1).unsqueeze(1)
else:
# already computed
conds = speech_conditioning_input.unsqueeze(1)
# Build input and target tensors
# Prepend start token to inputs and append stop token to targets
text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
# Set attn_mask
attn_mask_text = None
attn_mask_mel = None
if not return_latent:
attn_mask_text = torch.ones(
text_inputs.shape[0],
text_inputs.shape[1],
dtype=torch.bool,
device=text_inputs.device,
)
attn_mask_mel = torch.ones(
audio_codes.shape[0],
audio_codes.shape[1],
dtype=torch.bool,
device=audio_codes.device,
)
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
for idx, l in enumerate(mel_lengths):
attn_mask_mel[idx, l + 1 :] = 0.0
# Compute text embeddings + positional embeddings
# print(" > text input latent:", text_inputs)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
# Compute mel embeddings + positional embeddings
audio_emb = self.audio_embedding(audio_codes) + self.audio_pos_embedding(audio_codes)
# Compute prompt embeddings + positional embeddings
prompt = self.get_prompts(prompt_codes)
# prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
# dropout prompt embeddings
prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
# Get logits
sub = -4 # don't ask me why 😄
if self.training:
sub = -1
_, audio_logits = self.get_logits(
conds,
text_emb,
self.text_head,
audio_emb,
self.mel_head,
prompt=prompt_emb,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)
return audio_logits[:, :sub] # sub to prevent bla.
def compute_embeddings(
self,
speech_conditioning_latent,
text_inputs,
input_tokens=None,
prompt_codes=None,
pad_input_text=False,
):
"""Compute all the embeddings needed for inference."""
if pad_input_text and text_inputs.shape[1] < 250:
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
else:
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
print(" > Text inputs:", text_inputs)
if prompt_codes is not None:
prompt_codes = self.get_prompts(prompt_codes)
# prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
print(" > Prompt inputs:", prompt_codes)
print(" > Prompt inputs shape:", prompt_codes.shape)
emb = torch.cat([prompt_emb, emb], dim=1)
if speech_conditioning_latent is not None:
conds = speech_conditioning_latent.unsqueeze(1)
emb = torch.cat([conds, emb], dim=1)
self.inference_model.store_prefix_emb(emb)
fake_inputs = torch.full(
(
emb.shape[0],
emb.shape[1] + 1, # +1 for the start_mel_token
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
fake_inputs[:, -1] = self.start_mel_token
if input_tokens is not None:
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
return fake_inputs
def inference(
self,
text_inputs,
input_tokens=None,
prompt_codes=None,
pad_input_text=False,
**hf_generate_kwargs,
):
if pad_input_text and text_inputs.shape[1] < 250:
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
else:
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
if prompt_codes is not None:
prompt_codes = self.get_prompts(prompt_codes)
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
emb = torch.cat([prompt_emb, emb], dim=1)
self.inference_model.store_prefix_emb(emb)
fake_inputs = torch.full(
(
emb.shape[0],
emb.shape[1] + 1, # +1 for the start_mel_token
),
fill_value=1,
dtype=torch.long,
device=text_inputs.device,
)
fake_inputs[:, -1] = self.start_mel_token
if input_tokens is not None:
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
gen = self.inference_model.generate(
fake_inputs,
bos_token_id=self.start_mel_token,
pad_token_id=self.stop_mel_token,
eos_token_id=self.stop_mel_token,
max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
**hf_generate_kwargs,
)
if "return_dict_in_generate" in hf_generate_kwargs:
return gen.sequences[:, fake_inputs.shape[1] :], gen
return gen[:, fake_inputs.shape[1] :]

File diff suppressed because it is too large

View File

@@ -0,0 +1,136 @@
import math
import torch
from torch import nn
from transformers import GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class GPT2InferenceModel(GPT2PreTrainedModel):
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
super().__init__(config)
self.transformer = gpt
self.pos_embedding = pos_emb
self.embeddings = embeddings
self.final_norm = norm
self.lm_head = nn.Sequential(norm, linear)
self.kv_cache = kv_cache
def store_prefix_emb(self, prefix_emb):
self.cached_prefix_emb = prefix_emb
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) # usually None
if not self.kv_cache:
past_key_values = None
# only last token for inputs_ids if past is defined in kwargs
if past_key_values is not None:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
assert self.cached_prefix_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
# Create embedding
prefix_len = self.cached_prefix_emb.shape[1]
if input_ids.shape[1] != 1:
gen_inputs = input_ids[:, prefix_len:]
gen_emb = self.embeddings(gen_inputs)
gen_emb = gen_emb + self.pos_embedding(gen_emb)
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
prefix_emb = self.cached_prefix_emb.repeat_interleave(
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
)
else:
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
emb = torch.cat([prefix_emb, gen_emb], dim=1)
else:
emb = self.embeddings(input_ids)
emb = emb + self.pos_embedding.get_fixed_embedding(
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
)
transformer_outputs = self.transformer(
inputs_embeds=emb,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
if not return_dict:
return (lm_logits,) + transformer_outputs[1:]
return CausalLMOutputWithCrossAttentions(
loss=None,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past, beam_idx):
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
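# Illustrative call order (a sketch distilled from the inference code above; B, device and the
# *_mel_token values are placeholders, not new API): the caller caches the conditioning prefix
# once, then lets the HF generate() loop drive this wrapper one audio token at a time.
#
#   inference_model.store_prefix_emb(prefix_emb)                 # [B, prefix_len, C] text+prompt embedding
#   fake_inputs = torch.full((B, prefix_len + 1), 1, dtype=torch.long, device=device)
#   fake_inputs[:, -1] = start_mel_token                         # placeholders followed by start-of-mel token
#   codes = inference_model.generate(fake_inputs, bos_token_id=start_mel_token,
#                                    pad_token_id=stop_mel_token, eos_token_id=stop_mel_token)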

View File

@@ -0,0 +1,141 @@
# Originally ported from: https://github.com/neonbjb/tortoise-tts
import math
import torch
from torch import nn
from torch.nn import functional as F
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def normalization(channels):
groups = 32
if channels <= 16:
groups = 8
elif channels <= 64:
groups = 16
while channels % groups != 0:
groups = int(groups / 2)
assert groups > 2
return GroupNorm32(groups, channels)
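# For example, normalization(80) steps down from 32 to 16 groups (80 % 32 != 0 but 80 % 16 == 0)
# and returns GroupNorm32(16, 80), while normalization(64) returns GroupNorm32(16, 64).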
def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module
class QKVAttention(nn.Module):
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv, mask=None, qk_bias=0):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = weight + qk_bias
if mask is not None:
mask = mask.repeat(self.n_heads, 1, 1)
weight[mask.logical_not()] = -torch.inf
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
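# Shape walk-through (illustrative): with n_heads=4, per-head channels ch=64 and length T=100,
# qkv is [N, 4*3*64, 100]; q, k and v each become [N*4, 64, 100]; the attention weight is
# [N*4, 100, 100]; and the returned tensor is reshaped back to [N, 4*64, 100].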
class AttentionBlock(nn.Module):
"""An attention block that allows spatial positions to attend to each other."""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
out_channels=None,
do_activation=False,
):
super().__init__()
self.channels = channels
out_channels = channels if out_channels is None else out_channels
self.do_activation = do_activation
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, out_channels * 3, 1)
self.attention = QKVAttention(self.num_heads)
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
def forward(self, x, mask=None, qk_bias=0):
b, c, *spatial = x.shape
if mask is not None:
if len(mask.shape) == 2:
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
if mask.shape[1] != x.shape[-1]:
mask = mask[:, : x.shape[-1], : x.shape[-1]]
x = x.reshape(b, c, -1)
x = self.norm(x)
if self.do_activation:
x = F.silu(x, inplace=True)
qkv = self.qkv(x)
h = self.attention(qkv, mask=mask, qk_bias=qk_bias)
h = self.proj_out(h)
xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial)
class ConditioningEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
"""
x: (b, 80, s)
"""
h = self.init(x)
h = self.attn(h)
return h
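# Illustrative usage: ConditioningEncoder(80, 1024)(torch.randn(1, 80, 200)) maps an 80-band mel
# clip of 200 frames to a [1, 1024, 200] embedding sequence; the attention blocks preserve the
# sequence length, only the channel dimension changes in the initial 1x1 convolution.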

View File

@@ -0,0 +1,286 @@
import json
import os
import re
import inflect
import pandas as pd
import pypinyin
import torch
from num2words import num2words
from tokenizers import Tokenizer
from unidecode import unidecode
from TTS.tts.utils.text.cleaners import english_cleaners
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace(".", " point ")
def _expand_dollars(m):
match = m.group(1)
parts = match.split(".")
if len(parts) > 2:
return match + " dollars" # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = "dollar" if dollars == 1 else "dollars"
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = "dollar" if dollars == 1 else "dollars"
return "%s %s" % (dollars, dollar_unit)
elif cents:
cent_unit = "cent" if cents == 1 else "cents"
return "%s %s" % (cents, cent_unit)
else:
return "zero dollars"
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return "two thousand"
elif num > 2000 and num < 2010:
return "two thousand " + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r"\1 pounds", text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
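# Illustrative behaviour: normalize_numbers("It costs $10") -> "It costs ten dollars";
# normalize_numbers("the 3rd of May") -> "the third of May".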
# Regular expression matching whitespace:
_whitespace_re = re.compile(r"\s+")
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
text = text.replace('"', "")
return text
def expand_numbers_multilang(text, lang):
# TODO: Handle text more carefully. Currently, it just converts numbers without any context.
# Find all numbers in the input string
numbers = re.findall(r"\d+", text)
# Transliterate the numbers to text
for num in numbers:
transliterated_num = "".join(num2words(num, lang=lang))
text = text.replace(num, transliterated_num, 1)
return text
def transliteration_cleaners(text):
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def multilingual_cleaners(text, lang):
text = lowercase(text)
text = expand_numbers_multilang(text, lang)
text = collapse_whitespace(text)
text = text.replace('"', "")
if lang == "tr":
text = text.replace("İ", "i")
text = text.replace("Ö", "ö")
text = text.replace("Ü", "ü")
return text
def english_cleaners(text):
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', "")
return text
def remove_extraneous_punctuation(word):
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "\u2014": "-", "\u2013": "-", "`": "'", "ʼ": "'"}
replace = re.compile(
"|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
)
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
word = extraneous.sub("", word)
return word
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
_whitespace_re = re.compile(r"\s+")
def collapse_whitespace(text):
return re.sub(_whitespace_re, " ", text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def arabic_cleaners(text):
text = lowercase(text)
text = collapse_whitespace(text)
return text
def chinese_cleaners(text):
text = lowercase(text)
text = "".join(
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
)
return text
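# e.g. chinese_cleaners("你好") is expected to yield roughly "ni3hao3" (TONE3 pinyin with numeric tones).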
class VoiceBpeTokenizer:
def __init__(self, vocab_file=None, preprocess=None):
self.tokenizer = None
if vocab_file is not None:
with open(vocab_file, "r", encoding="utf-8") as f:
vocab = json.load(f)
self.language = vocab["model"]["language"] if "language" in vocab["model"] else None
if preprocess is None:
self.preprocess = "pre_tokenizer" in vocab and vocab["pre_tokenizer"]
else:
self.preprocess = preprocess
self.tokenizer = Tokenizer.from_file(vocab_file)
def preprocess_text(self, txt, lang):
if lang == "ja":
import pykakasi
kks = pykakasi.kakasi()
results = kks.convert(txt)
txt = " ".join([result["kana"] for result in results])
txt = basic_cleaners(txt)
elif lang == "en":
txt = english_cleaners(txt)
elif lang == "ar":
txt = arabic_cleaners(txt)
elif lang == "zh-cn":
txt = chinese_cleaners(txt)
else:
txt = multilingual_cleaners(txt, lang)
return txt
def encode(self, txt, lang):
if self.preprocess:
txt = self.preprocess_text(txt, lang)
txt = txt.replace(" ", "[SPACE]")
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
txt = txt.replace("[SPACE]", " ")
txt = txt.replace("[STOP]", "")
txt = txt.replace("[UNK]", "")
return txt
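# Minimal usage sketch (the vocab path is a placeholder; XTTS ships a vocab.json next to its checkpoint):
#   tokenizer = VoiceBpeTokenizer(vocab_file="/path/to/xtts/vocab.json")
#   ids = tokenizer.encode("It is five o'clock.", lang="en")   # cleaned, spaces mapped to [SPACE], then BPE
#   text = tokenizer.decode(ids)                               # roughly round-trips back to the cleaned text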

View File

@@ -0,0 +1,385 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
MAX_WAV_VALUE = 32768.0
class KernelPredictor(torch.nn.Module):
"""Kernel predictor for the location-variable convolutions"""
def __init__(
self,
cond_channels,
conv_in_channels,
conv_out_channels,
conv_layers,
conv_kernel_size=3,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
kpnet_nonlinear_activation="LeakyReLU",
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
):
"""
Args:
cond_channels (int): number of channel for the conditioning sequence,
conv_in_channels (int): number of channel for the input sequence,
conv_out_channels (int): number of channel for the output sequence,
conv_layers (int): number of layers
"""
super().__init__()
self.conv_in_channels = conv_in_channels
self.conv_out_channels = conv_out_channels
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
padding = (kpnet_conv_size - 1) // 2
for _ in range(3):
self.residual_convs.append(
nn.Sequential(
nn.Dropout(kpnet_dropout),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_hidden_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_kernel_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
self.bias_conv = nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
kpnet_bias_channels,
kpnet_conv_size,
padding=padding,
bias=True,
)
)
def forward(self, c):
"""
Args:
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
"""
batch, _, cond_length = c.shape
c = self.input_conv(c)
for residual_conv in self.residual_convs:
residual_conv.to(c.device)
c = c + residual_conv(c)
k = self.kernel_conv(c)
b = self.bias_conv(c)
kernels = k.contiguous().view(
batch,
self.conv_layers,
self.conv_in_channels,
self.conv_out_channels,
self.conv_kernel_size,
cond_length,
)
bias = b.contiguous().view(
batch,
self.conv_layers,
self.conv_out_channels,
cond_length,
)
return kernels, bias
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input_conv[0])
nn.utils.remove_weight_norm(self.kernel_conv)
nn.utils.remove_weight_norm(self.bias_conv)
for block in self.residual_convs:
nn.utils.remove_weight_norm(block[1])
nn.utils.remove_weight_norm(block[3])
class LVCBlock(torch.nn.Module):
"""the location-variable convolutions"""
def __init__(
self,
in_channels,
cond_channels,
stride,
dilations=[1, 3, 9, 27],
lReLU_slope=0.2,
conv_kernel_size=3,
cond_hop_length=256,
kpnet_hidden_channels=64,
kpnet_conv_size=3,
kpnet_dropout=0.0,
):
super().__init__()
self.cond_hop_length = cond_hop_length
self.conv_layers = len(dilations)
self.conv_kernel_size = conv_kernel_size
self.kernel_predictor = KernelPredictor(
cond_channels=cond_channels,
conv_in_channels=in_channels,
conv_out_channels=2 * in_channels,
conv_layers=len(dilations),
conv_kernel_size=conv_kernel_size,
kpnet_hidden_channels=kpnet_hidden_channels,
kpnet_conv_size=kpnet_conv_size,
kpnet_dropout=kpnet_dropout,
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
)
self.convt_pre = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.ConvTranspose1d(
in_channels,
in_channels,
2 * stride,
stride=stride,
padding=stride // 2 + stride % 2,
output_padding=stride % 2,
)
),
)
self.conv_blocks = nn.ModuleList()
for dilation in dilations:
self.conv_blocks.append(
nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.Conv1d(
in_channels,
in_channels,
conv_kernel_size,
padding=dilation * (conv_kernel_size - 1) // 2,
dilation=dilation,
)
),
nn.LeakyReLU(lReLU_slope),
)
)
def forward(self, x, c):
"""forward propagation of the location-variable convolutions.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length)
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
Returns:
Tensor: the output sequence (batch, in_channels, in_length)
"""
_, in_channels, _ = x.shape # (B, c_g, L')
x = self.convt_pre(x) # (B, c_g, stride * L')
kernels, bias = self.kernel_predictor(c)
for i, conv in enumerate(self.conv_blocks):
output = conv(x) # (B, c_g, stride * L')
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
output = self.location_variable_convolution(
output, k, b, hop_size=self.cond_hop_length
) # (B, 2 * c_g, stride * L'): LVC
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
output[:, in_channels:, :]
) # (B, c_g, stride * L'): GAU
return x
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
Args:
x (Tensor): the input sequence (batch, in_channels, in_length).
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
dilation (int): the dilation of convolution.
hop_size (int): the hop_size of the conditioning sequence.
Returns:
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
x = x.unfold(
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
o = o + bias
o = o.contiguous().view(batch, out_channels, -1)
return o
def remove_weight_norm(self):
self.kernel_predictor.remove_weight_norm()
nn.utils.remove_weight_norm(self.convt_pre[1])
for block in self.conv_blocks:
nn.utils.remove_weight_norm(block[1])
class UnivNetGenerator(nn.Module):
"""
UnivNet Generator
Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
"""
def __init__(
self,
noise_dim=64,
channel_size=32,
dilations=[1, 3, 9, 27],
strides=[8, 8, 4],
lReLU_slope=0.2,
kpnet_conv_size=3,
# Below are MEL configurations options that this generator requires.
hop_length=256,
n_mel_channels=100,
):
super(UnivNetGenerator, self).__init__()
self.mel_channel = n_mel_channels
self.noise_dim = noise_dim
self.hop_length = hop_length
channel_size = channel_size
kpnet_conv_size = kpnet_conv_size
self.res_stack = nn.ModuleList()
hop_length = 1
for stride in strides:
hop_length = stride * hop_length
self.res_stack.append(
LVCBlock(
channel_size,
n_mel_channels,
stride=stride,
dilations=dilations,
lReLU_slope=lReLU_slope,
cond_hop_length=hop_length,
kpnet_conv_size=kpnet_conv_size,
)
)
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
def forward(self, c, z):
"""
Args:
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
z (Tensor): the noise sequence (batch, noise_dim, in_length)
"""
z = self.conv_pre(z) # (B, c_g, L)
for res_block in self.res_stack:
res_block.to(z.device)
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
z = self.conv_post(z) # (B, 1, L * 256)
return z
def eval(self, inference=False):
super(UnivNetGenerator, self).eval()
# don't remove weight norm while validation in training loop
if inference:
self.remove_weight_norm()
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.conv_pre)
for layer in self.conv_post:
if len(layer.state_dict()) != 0:
nn.utils.remove_weight_norm(layer)
for res_block in self.res_stack:
res_block.remove_weight_norm()
def inference(self, c, z=None):
# pad input mel with zeros to cut artifact
# see https://github.com/seungwonpark/melgan/issues/8
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
mel = torch.cat((c, zero), dim=2)
if z is None:
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
audio = self.forward(mel, z)
audio = audio[:, :, : -(self.hop_length * 10)]
audio = audio.clamp(min=-1, max=1)
return audio
if __name__ == "__main__":
model = UnivNetGenerator()
c = torch.randn(3, 100, 10)
z = torch.randn(3, 64, 10)
print(c.shape)
y = model(c, z)
print(y.shape)
assert y.shape == torch.Size([3, 1, 2560])
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)

TTS/tts/models/xtts.py Normal file
View File

@@ -0,0 +1,654 @@
import os
from contextlib import contextmanager
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
def load_audio(audiopath, sr=22050):
"""
Load an audio file from disk and resample it to the specified sampling rate.
Args:
audiopath (str): Path to the audio file.
sr (int): Target sampling rate.
Returns:
Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
"""
audio, sampling_rate = torchaudio.load(audiopath)
if len(audio.shape) > 1:
if audio.shape[0] < 5:
audio = audio[0]
else:
assert audio.shape[1] < 5
audio = audio[:, 0]
if sampling_rate != sr:
resampler = torchaudio.transforms.Resample(sampling_rate, sr)
audio = resampler(audio)
audio = audio.clamp_(-1, 1)
return audio.unsqueeze(0)
def wav_to_mel_cloning(
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
):
"""
Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
Args:
wav (torch.Tensor): Input waveform tensor.
mel_norms_file (str): Path to mel-spectrogram normalization file.
mel_norms (torch.Tensor): Mel-spectrogram normalization tensor.
device (torch.device): Device to use for computation.
Returns:
torch.Tensor: Mel-spectrogram tensor.
"""
mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
norm="slaney",
).to(device)
wav = wav.to(device)
mel = mel_stft(wav)
mel = torch.log(torch.clamp(mel, min=1e-5))
if mel_norms is None:
mel_norms = torch.load(mel_norms_file, map_location=device)
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
return mel
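# Rough shape note: with the hard-coded 22050 Hz / hop 1024 settings above, a 3 second reference
# clip (66150 samples) yields a normalized mel of shape [1, 80, ~65].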
def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
Args:
t (torch.Tensor): The input tensor to be padded or truncated.
length (int): The desired length of the tensor.
Returns:
torch.Tensor: The padded or truncated tensor.
"""
tp = t[..., :length]
if t.shape[-1] == length:
tp = t
elif t.shape[-1] < length:
tp = F.pad(t, (0, length - t.shape[-1]))
return tp
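# Behaviour check (illustrative): pad_or_truncate(torch.ones(1, 5), 8) is zero-padded on the right
# to length 8, while pad_or_truncate(torch.ones(1, 5), 3) is clipped to the first 3 samples.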
def load_discrete_vocoder_diffuser(
trained_diffusion_steps=4000,
desired_diffusion_steps=200,
cond_free=True,
cond_free_k=1,
sampler="ddim",
):
"""
Load a GaussianDiffusion instance configured for use as a decoder.
Args:
trained_diffusion_steps (int): The number of diffusion steps used during training.
desired_diffusion_steps (int): The number of diffusion steps to use during inference.
cond_free (bool): Whether to use a conditioning-free model.
cond_free_k (int): The number of samples to use for conditioning-free models.
sampler (str): The name of the sampler to use.
Returns:
A SpacedDiffusion instance configured with the given parameters.
"""
return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon",
model_var_type="learned_range",
loss_type="mse",
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k,
sampler=sampler,
)
def do_spectrogram_diffusion(
diffusion_model,
diffuser,
latents,
conditioning_latents,
temperature=1,
):
"""
Generate a mel-spectrogram using a diffusion model and a diffuser.
Args:
diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
Returns:
torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
"""
with torch.no_grad():
output_seq_len = (
latents.shape[1] * 4 * 24000 // 22050
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion_model.timestep_independent(
latents, conditioning_latents, output_seq_len, False
)
noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.sample_loop(
diffusion_model,
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=False,
)
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
@dataclass
class XttsAudioConfig(Coqpit):
"""
Configuration class for audio-related parameters in the XTTS model.
Args:
sample_rate (int): The sample rate in which the GPT operates.
diffusion_sample_rate (int): The sample rate of the diffusion audio waveform.
output_sample_rate (int): The sample rate of the output audio waveform.
"""
sample_rate: int = 22050
diffusion_sample_rate: int = 24000
output_sample_rate: int = 24000
@dataclass
class XttsArgs(Coqpit):
"""A dataclass to represent XTTS model arguments that define the model structure.
Args:
gpt_batch_size (int): The size of the auto-regressive batch.
enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
For GPT model:
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
ar_max_prompt_tokens (int, optional): The maximum prompt tokens for the autoregressive model. Defaults to 70.
ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
For DiffTTS model:
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
"""
gpt_batch_size: int = 1
enable_redaction: bool = False
lazy_load: bool = True
kv_cache: bool = True
gpt_checkpoint: str = None
clvp_checkpoint: str = None
decoder_checkpoint: str = None
num_chars: int = 255
# XTTS GPT Encoder params
tokenizer_file: str = ""
gpt_max_audio_tokens: int = 605
gpt_max_text_tokens: int = 402
gpt_max_prompt_tokens: int = 70
gpt_layers: int = 30
gpt_n_model_channels: int = 1024
gpt_n_heads: int = 16
gpt_number_text_tokens: int = None
gpt_start_text_token: int = None
gpt_stop_text_token: int = None
gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193
# Diffusion Decoder params
diff_model_channels: int = 1024
diff_num_layers: int = 10
diff_in_channels: int = 100
diff_out_channels: int = 200
diff_in_latent_channels: int = 1024
diff_in_tokens: int = 8193
diff_dropout: int = 0
diff_use_fp16: bool = False
diff_num_heads: int = 16
diff_layer_drop: int = 0
diff_unconditioned_percentage: int = 0
# constants
duration_const: int = 102400
class Xtts(BaseTTS):
"""ⓍTTS model implementation.
Currently it only supports inference.
Examples:
>>> from TTS.tts.configs.xtts_config import XttsConfig
>>> from TTS.tts.models.xtts import Xtts
>>> config = XttsConfig()
>>> model = Xtts.init_from_config(config)
>>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
"""
def __init__(self, config: Coqpit):
super().__init__(config, ap=None, tokenizer=None)
self.lazy_load = self.args.lazy_load
self.mel_stats_path = None
self.config = config
self.gpt_checkpoint = self.args.gpt_checkpoint
self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed
self.models_dir = config.model_dir
self.gpt_batch_size = self.args.gpt_batch_size
self.tokenizer = VoiceBpeTokenizer()
self.gpt = None
self.diffusion_decoder = None
self.init_models()
self.register_buffer("mel_stats", torch.ones(80))
def init_models(self):
"""Initialize the models. We do it here since we need to load the tokenizer first."""
if self.tokenizer.tokenizer is not None:
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
if self.args.gpt_number_text_tokens:
self.gpt = GPT(
layers=self.args.gpt_layers,
model_dim=self.args.gpt_n_model_channels,
start_text_token=self.args.gpt_start_text_token,
stop_text_token=self.args.gpt_stop_text_token,
heads=self.args.gpt_n_heads,
max_text_tokens=self.args.gpt_max_text_tokens,
max_mel_tokens=self.args.gpt_max_audio_tokens,
max_prompt_tokens=self.args.gpt_max_prompt_tokens,
number_text_tokens=self.args.gpt_number_text_tokens,
num_audio_tokens=self.args.gpt_num_audio_tokens,
start_audio_token=self.args.gpt_start_audio_token,
stop_audio_token=self.args.gpt_stop_audio_token,
)
self.diffusion_decoder = DiffusionTts(
model_channels=self.args.diff_model_channels,
num_layers=self.args.diff_num_layers,
in_channels=self.args.diff_in_channels,
out_channels=self.args.diff_out_channels,
in_latent_channels=self.args.diff_in_latent_channels,
in_tokens=self.args.diff_in_tokens,
dropout=self.args.diff_dropout,
use_fp16=self.args.diff_use_fp16,
num_heads=self.args.diff_num_heads,
layer_drop=self.args.diff_layer_drop,
unconditioned_percentage=self.args.diff_unconditioned_percentage,
)
self.vocoder = UnivNetGenerator()
@property
def device(self):
return next(self.parameters()).device
@contextmanager
def lazy_load_model(self, model):
"""Context to load a model on demand.
Args:
model (nn.Module): The model to be loaded.
"""
if self.lazy_load:
yield model
else:
m = model.to(self.device)
yield m
m = model.cpu()
def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
"""Compute the conditioning latents for the GPT model from the given audio.
Args:
audio_path (str): Path to the audio file.
length (int): Length of the audio in seconds. Defaults to 3.
"""
audio = load_audio(audio_path)
audio = audio[:, : 22050 * length]
mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
return cond_latent.transpose(1, 2)
def get_diffusion_cond_latents(
self,
audio_path,
):
from math import ceil
diffusion_conds = []
CHUNK_SIZE = 102400
audio = load_audio(audio_path, 24000)
for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
do_normalization=False,
device=self.device,
)
diffusion_conds.append(cond_mel)
diffusion_conds = torch.stack(diffusion_conds, dim=1)
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
return diffusion_latent
def get_conditioning_latents(
self,
audio_path,
gpt_cond_len=3,
):
gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
diffusion_cond_latents = self.get_diffusion_cond_latents(
audio_path,
)
return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
def synthesize(self, text, config, speaker_wav, language, **kwargs):
"""Synthesize speech with the given input text.
Args:
text (str): Input text.
config (XttsConfig): Config with inference parameters.
speaker_wav (str): Path to the speaker audio file for cloning.
language (str): Language ID of the speaker.
**kwargs: Inference settings. See `inference()`.
Returns:
A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference,
`text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents`
as latents used at inference.
"""
# Make the synthesizer happy 🥳
if isinstance(speaker_wav, list):
speaker_wav = speaker_wav[0]
return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
"""
inference with config
"""
assert (
language in self.config.languages
), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
# Use generally found best tuning knobs for generation.
settings = {
"temperature": config.temperature,
"length_penalty": config.length_penalty,
"repetition_penalty": config.repetition_penalty,
"top_k": config.top_k,
"top_p": config.top_p,
"cond_free_k": config.cond_free_k,
"diffusion_temperature": config.diffusion_temperature,
"decoder_iterations": config.decoder_iterations,
"decoder_sampler": config.decoder_sampler,
}
settings.update(kwargs) # allow overriding of preset settings with kwargs
return self.inference(text, ref_audio_path, language, **settings)
@torch.no_grad()
def inference(
self,
text,
ref_audio_path,
language,
# GPT inference
temperature=0.65,
length_penalty=1,
repetition_penalty=2.0,
top_k=50,
top_p=0.85,
gpt_cond_len=4,
do_sample=True,
# Decoder inference
decoder_iterations=100,
cond_free=True,
cond_free_k=2,
diffusion_temperature=1.0,
decoder_sampler="ddim",
**hf_generate_kwargs,
):
"""
This function produces an audio clip of the given text being spoken with the given reference voice.
Args:
text: (str) Text to be spoken.
ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3
seconds long.
language: (str) Language of the voice to be generated.
temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65.
length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings cause the
model to produce more terse outputs. Defaults to 1.0.
repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during
decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0.
top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely"
(aka boring) outputs. Defaults to 50.
top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely"
(aka boring) outputs. Defaults to 0.85.
gpt_cond_len: (int) Length in seconds of the reference audio used for cloning. If the audio is shorter, the whole
clip is used; otherwise only the first `gpt_cond_len` seconds are used. Defaults to 4 seconds.
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
Generally a value above 250 is not noticeably better, however. Defaults to 100.
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
Defaults to True.
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
conditioning-free signal. Defaults to 2.0.
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
Defaults to 1.0.
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
here: https://huggingface.co/docs/transformers/internal/generation_utils
Returns:
Generated audio clip(s) as a torch tensor. Shape (1, S) if k=1, else (k, 1, S), where S is the sample length.
Sample rate is 24kHz.
"""
text = f"[{language}]{text.strip().lower()}"
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
assert (
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
(
gpt_cond_latent,
diffusion_conditioning,
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=decoder_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=decoder_sampler,
)
with torch.no_grad():
self.gpt = self.gpt.to(self.device)
with self.lazy_load_model(self.gpt) as gpt:
gpt_codes = gpt.generate(
cond_latents=gpt_cond_latent,
text_inputs=text_tokens,
input_tokens=None,
do_sample=do_sample,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_return_sequences=self.gpt_batch_size,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
output_attentions=False,
**hf_generate_kwargs,
)
with self.lazy_load_model(self.gpt) as gpt:
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)
text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = gpt(
text_tokens,
text_len,
gpt_codes,
expected_output_len,
cond_latents=gpt_cond_latent,
return_attentions=False,
return_latent=True,
)
silence_token = 83
ctokens = 0
for k in range(gpt_codes.shape[-1]):
if gpt_codes[0, k] == silence_token:
ctokens += 1
else:
ctokens = 0
if ctokens > 8:
gpt_latents = gpt_latents[:, :k]
break
with self.lazy_load_model(self.diffusion_decoder) as diffusion:
mel = do_spectrogram_diffusion(
diffusion,
diffuser,
gpt_latents,
diffusion_conditioning,
temperature=diffusion_temperature,
)
with self.lazy_load_model(self.vocoder) as vocoder:
wav = vocoder.inference(mel)
return {"wav": wav.cpu().numpy().squeeze()}
def forward(self):
raise NotImplementedError("XTTS Training is not implemented")
def eval_step(self):
raise NotImplementedError("XTTS Training is not implemented")
@staticmethod
def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument
return Xtts(config)
def eval(self): # pylint: disable=redefined-builtin
"""Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode."""
self.gpt.init_gpt_for_inference()
super().eval()
def load_checkpoint(
self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
):
"""
Loads a checkpoint from disk and initializes the model's state and tokenizer.
Args:
config (dict): The configuration dictionary for the model.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
Returns:
None
"""
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
self.init_models()
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)
if eval:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
self.gpt.eval()
self.diffusion_decoder.eval()
self.vocoder.eval()
def train_step(self):
raise NotImplementedError("XTTS Training is not implemented")

View File

@@ -8,7 +8,9 @@ def init():
import jpype
import jpype.imports
except ModuleNotFoundError:
raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.")
raise ModuleNotFoundError(
"Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`."
)
try:
jar_path = os.environ["BEL_FANETYKA_JAR"]
@@ -31,4 +33,5 @@ def belarusian_text_to_phonemes(text: str) -> str:
init()
from org.alex73.fanetyka.impl import FanetykaText
return str(FanetykaText(finder, text).ipa)

View File

@@ -1,6 +1,6 @@
from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer
from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer

View File

@@ -1,7 +1,7 @@
from typing import Dict
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
_DEF_BE_PUNCS = ",!." # TODO

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
import subprocess
@@ -219,3 +220,22 @@ class KeepAverage:
def update_values(self, value_dict):
for key, value in value_dict.items():
self.update_value(key, value)
def get_timestamp():
return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
lg = logging.getLogger(logger_name)
formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
lg.setLevel(level)
if tofile:
log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
fh = logging.FileHandler(log_file, mode="w")
fh.setFormatter(formatter)
lg.addHandler(fh)
if screen:
sh = logging.StreamHandler()
sh.setFormatter(formatter)
lg.addHandler(sh)
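# Illustrative usage (paths are placeholders):
#   setup_logger("train", "/tmp/logs", "train", screen=True, tofile=True)
#   logging.getLogger("train").info("hello")  # written to stdout and /tmp/logs/train_<timestamp>.log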

View File

@@ -334,7 +334,7 @@ class ModelManager(object):
output_model_path = output_path
output_config_path = None
if (
model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name
model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
): # TODO:This is stupid but don't care for now.
output_model_path, output_config_path = self._find_files(output_path)
# update paths in the config.json

View File

@@ -242,7 +242,11 @@ class Synthesizer(nn.Module):
wav (List[int]): waveform as a list of values.
path (str): output path to save the waveform.
"""
wav = np.array(wav)
# if tensor convert to numpy
if torch.is_tensor(wav):
wav = wav.cpu().numpy()
if isinstance(wav, list):
wav = np.array(wav)
save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
@@ -334,7 +338,7 @@ class Synthesizer(nn.Module):
elif language_name and isinstance(language_name, str):
try:
language_id = self.tts_model.language_manager.name_to_id[language_name]
language_id = self.tts_model.language_manager.name_to_id[language_id]
except KeyError as e:
raise ValueError(
f" [!] Looks like you use a multi-lingual model. "
@@ -374,6 +378,8 @@ class Synthesizer(nn.Module):
speaker_id=speaker_name,
voice_dirs=self.voice_dir,
d_vector=speaker_embedding,
speaker_wav=speaker_wav,
language=language_name,
**kwargs,
)
else:

View File

@@ -53,6 +53,7 @@
models/overflow.md
models/tortoise.md
models/bark.md
models/xtts.md
.. toctree::
:maxdepth: 2

View File

@@ -1,4 +1,4 @@
# Bark 🐶
# 🐶 Bark
Bark is a multi-lingual TTS model created by [Suno-AI](https://www.suno.ai/). It can generate conversational speech as well as music and sound effects.
It is architecturally very similar to Google's [AudioLM](https://arxiv.org/abs/2209.03143). For more information, please refer to the [Suno-AI's repo](https://github.com/suno-ai/bark).

View File

@@ -1,4 +1,4 @@
# Tortoise 🐢
# 🐢 Tortoise
Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on a GPT-like autoregressive acoustic model that converts input
text to discretized acoustic tokens, a diffusion model that converts these tokens to mel-spectrogram frames, and a UnivNet vocoder to convert the spectrograms to
the final audio signal. The important downside is that Tortoise is very slow compared to parallel TTS models like VITS.

docs/source/models/xtts.md Normal file
View File

@@ -0,0 +1,108 @@
# ⓍTTS
ⓍTTS is a super cool Text-to-Speech model that lets you clone voices in different languages by using just a quick 3-second audio clip. Built on 🐢Tortoise,
ⓍTTS has important model changes that make cross-language voice cloning and multi-lingual speech generation super easy.
There is no need for an excessive amount of training data that spans countless hours.
This is the same model that powers [Coqui Studio](https://coqui.ai/) and [Coqui API](https://docs.coqui.ai/docs); however, we apply
a few tricks to make it faster and support streaming inference.
### Features
- Voice cloning with just a 3-second audio clip.
- Cross-language voice cloning.
- Multi-lingual speech generation.
- 24 kHz sampling rate.
### Code
The current implementation only supports inference.
### Languages
As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).
Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
### License
This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml).
### Contact
Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
You can also mail us at info@coqui.ai.
Using 🐸TTS API:
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
language="en")
# generate speech by cloning a voice using custom settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
file_path="output.wav",
speaker_wav="/path/to/target/speaker.wav",
language="en",
decoder_iterations=30)
```
Using 🐸TTS Command line:
```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
--text "Bugün okula gitmek istemiyorum." \
--speaker_wav /path/to/target/speaker.wav \
--language_idx tr \
--use_cuda true
```
Using model directly:
```python
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
model.cuda()
outputs = model.synthesize(
"It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
config,
speaker_wav="/data/TTS-public/_refclips/3.wav",
gpt_cond_len=3,
language="en",
)
```
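The dictionary returned by `synthesize()` carries the generated waveform under the `wav` key as a NumPy float array at the 24 kHz output sample rate. A minimal sketch for writing it to disk (`scipy` is used here only as an example, any WAV writer works):
```python
import scipy.io.wavfile as wavfile

# outputs["wav"] is a 1-D float array in [-1, 1]; 24000 matches XttsAudioConfig.output_sample_rate
wavfile.write("xtts_output.wav", 24000, outputs["wav"])
```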
## Important resources & papers
- VallE: https://arxiv.org/abs/2301.02111
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
- Univnet: https://arxiv.org/abs/2106.07889
- Latent Diffusion: https://arxiv.org/abs/2112.10752
- DALL-E: https://arxiv.org/abs/2102.12092
## XttsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.xtts_config.XttsConfig
:members:
```
## XttsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.xtts.XttsArgs
:members:
```
## XTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.xtts.Xtts
:members:
```

View File

@@ -60,7 +60,7 @@ config = GlowTTSConfig(
output_path=output_path,
add_blank=True,
datasets=[dataset_config],
# characters=characters,
# characters=characters,
enable_eos_bos_chars=True,
mixed_precision=False,
save_step=10000,

View File

@@ -1,6 +1,6 @@
import os
import warnings
import unittest
import warnings
from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
@@ -17,7 +17,8 @@ class TestText(unittest.TestCase):
except KeyError:
warnings.warn(
"You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
Warning)
Warning,
)
return
for line in _TEST_CASES.strip().split("\n"):