mirror of https://github.com/coqui-ai/TTS.git
🔥 XTTS implementation
This commit is contained in:
parent 4d3f23b5d3
commit 4033db5f4b
@ -111,6 +111,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
- Delightful TTS: [paper](https://arxiv.org/abs/2110.12612)

### End-to-End Models

- ⓍTTS: [blog]()
- VITS: [paper](https://arxiv.org/pdf/2106.06103)
- 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418)
- 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts)
@ -248,11 +249,11 @@ tts.tts_with_vc_to_file(
```

#### Example using [🐸Coqui Studio](https://coqui.ai) voices.
You can access all of your cloned voices and built-in speakers in [🐸Coqui Studio](https://coqui.ai).
To do this, you'll need an API token, which you can obtain from the [account page](https://coqui.ai/account).
After obtaining the API token, set the `COQUI_STUDIO_TOKEN` environment variable.

Once you have a valid API token in place, the studio speakers will be displayed as distinct models within the list.
These models follow the naming convention `coqui_studio/en/<studio_speaker_name>/coqui_studio`.

```python
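# Illustrative sketch: "<studio_speaker_name>" is a placeholder for one of your studio speakers
# and COQUI_STUDIO_TOKEN must be set in the environment.
from TTS.api import TTS

tts = TTS(model_name="coqui_studio/en/<studio_speaker_name>/coqui_studio")
tts.tts_to_file(text="This is a test.", file_path="output.wav")
```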
@ -2,6 +2,18 @@
    "tts_models": {
        "multilingual": {
            "multi-dataset": {
                "xtts_v1": {
                    "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
                    "hf_url": [
                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
                        "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
                    ],
                    "default_vocoder": null,
                    "commit": "e9a1953e",
                    "license": "Coqui Community Model License",
                    "contact": "info@coqui.ai"
                },
                "your_tts": {
                    "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418",
                    "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip",

@ -881,4 +893,4 @@
                }
            }
        }
    }
}
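For reference, a minimal sketch of how this registry entry is typically consumed from Python (file paths are placeholders); the model name follows the `tts_models/<language>/<dataset>/<model>` layout of the JSON above:

```python
from TTS.api import TTS

# Resolves to the "xtts_v1" entry above; the files listed under "hf_url" are fetched on first use.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
tts.tts_to_file(
    text="Hello world!",
    speaker_wav="reference.wav",  # placeholder path to a short reference clip for voice cloning
    language="en",
    file_path="xtts_out.wav",
)
```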
@ -105,6 +105,9 @@ class TTS(nn.Module):

    @property
    def is_multi_lingual(self):
        # TODO: fix this
        if "xtts" in self.model_name:
            return True
        if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
            return self.synthesizer.tts_model.language_manager.num_languages > 1
        return False
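A small illustrative sketch (paths are placeholders) of how this property is usually consumed: multi-lingual models such as XTTS need an explicit `language`, mono-lingual models do not:

```python
from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")

kwargs = {"speaker_wav": "reference.wav"}  # placeholder reference clip
if tts.is_multi_lingual:
    kwargs["language"] = "fr"  # required only for multi-lingual models
tts.tts_to_file(text="Bonjour le monde", file_path="out.wav", **kwargs)
```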
@ -392,7 +392,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
    if args.encoder_path is not None:
        encoder_path = args.encoder_path
        encoder_config_path = args.encoder_config_path

    device = args.device
    if args.use_cuda:
        device = "cuda"

@ -459,7 +459,9 @@ If you don't specify any models, then it uses LJSpeech based English model.
            target_wav=args.target_wav,
        )
    elif model_dir is not None:
        wav = synthesizer.tts(args.text, speaker_name=args.speaker_idx)
        wav = synthesizer.tts(
            args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
        )

    # save the results
    print(" > Saving output to {}".format(args.out_path))
@ -0,0 +1,90 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig


@dataclass
class XttsConfig(BaseTTSConfig):
    """Defines parameters for the XTTS TTS model.

    Args:
        model (str):
            Model name. Do not change unless you know what you are doing.

        model_args (XttsArgs):
            Model architecture arguments. Defaults to `XttsArgs()`.

        audio (XttsAudioConfig):
            Audio processing configuration. Defaults to `XttsAudioConfig()`.

        model_dir (str):
            Path to the folder that has all the XTTS models. Defaults to None.

        temperature (float):
            Temperature for the autoregressive model inference. Larger values make predictions more creative at the cost of stability. Defaults to `0.2`.

        length_penalty (float):
            Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length,
            which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative),
            length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.

        repetition_penalty (float):
            The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`.

        top_p (float):
            If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
            Defaults to `0.8`.

        cond_free_k (float):
            Knob that determines how to balance the conditioning-free signal with the conditioning-present signal. [0, inf].
            As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
            Formula is: output = cond_present_output * (cond_free_k + 1) - cond_absent_output * cond_free_k. Defaults to `2.0`.

        diffusion_temperature (float):
            Controls the variance of the noise fed into the diffusion model. [0, 1]. Values at 0
            are the "mean" prediction of the diffusion network and will sound bland and smeared.
            Defaults to `1.0`.

        num_gpt_outputs (int):
            Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
            As XTTS is a probabilistic model, more samples give a higher probability of creating something "great".
            Defaults to `16`.

        decoder_iterations (int):
            Number of diffusion steps to perform. [0, 4000]. More steps give the network more chances to iteratively refine
            the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
            however. Defaults to `30`.

        decoder_sampler (str):
            Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.

    Note:
        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

    Example:

        >>> from TTS.tts.configs.xtts_config import XttsConfig
        >>> config = XttsConfig()
    """

    model: str = "xtts"

    # model specific params
    model_args: XttsArgs = field(default_factory=XttsArgs)
    audio: XttsAudioConfig = field(default_factory=XttsAudioConfig)
    model_dir: str = None
    languages: List[str] = field(
        default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"]
    )

    # inference params
    temperature: float = 0.2
    length_penalty: float = 1.0
    repetition_penalty: float = 2.0
    top_k: int = 50
    top_p: float = 0.8
    cond_free_k: float = 2.0
    diffusion_temperature: float = 1.0
    num_gpt_outputs: int = 16
    decoder_iterations: int = 30
    decoder_sampler: str = "ddim"
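As a short illustrative sketch, the inference parameters documented above can be overridden at construction time (the values below are examples, not the model's shipped settings):

```python
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig(
    temperature=0.1,          # lower temperature -> less creative, more stable output
    repetition_penalty=2.5,
    top_k=50,
    top_p=0.8,
    decoder_iterations=100,   # more diffusion steps: slower, potentially cleaner audio
)
print(config.languages)       # ['en', 'es', 'fr', ...]
```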
@ -5,15 +5,14 @@ from tokenizers import Tokenizer

from TTS.tts.utils.text.cleaners import english_cleaners

DEFAULT_VOCAB_FILE = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
)


class VoiceBpeTokenizer:
    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
    def __init__(self, vocab_file=None, vocab_str=None):
        self.tokenizer = None
        if vocab_file is not None:
            self.tokenizer = Tokenizer.from_file(vocab_file)
        if vocab_str is not None:
            self.tokenizer = Tokenizer.from_str(vocab_str)

    def preprocess_text(self, txt):
        txt = english_cleaners(txt)
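A minimal sketch of the two construction paths this change enables (the import path and vocab path are illustrative, not taken from the diff):

```python
# Module path may differ in the repo; "vocab.json" is a placeholder path.
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

# Load the BPE tokenizer from a vocab file on disk ...
tok = VoiceBpeTokenizer(vocab_file="vocab.json")

# ... or from the raw JSON string, e.g. when the vocab ships inside a checkpoint folder.
with open("vocab.json", "r", encoding="utf-8") as f:
    tok_from_str = VoiceBpeTokenizer(vocab_str=f.read())

print(tok.preprocess_text("Hello, World!"))  # English cleaners applied before tokenization
```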
File diff suppressed because it is too large
@ -0,0 +1,393 @@
|
|||
import functools
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
import torch.distributed as distributed
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if val is not None else d
|
||||
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
out = fn(model, *args, **kwargs)
|
||||
model.train(was_training)
|
||||
return out
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def dvae_wav_to_mel(
|
||||
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
|
||||
):
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
wav = wav.to(device)
|
||||
mel = mel_stft(wav)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
if mel_norms is None:
|
||||
mel_norms = torch.load(mel_norms_file, map_location=device)
|
||||
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||
return mel
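A small usage sketch for the helper above (file paths are placeholders); it expects 22.05 kHz audio and per-mel-bin normalization statistics:

```python
import torch
import torchaudio

wav, sr = torchaudio.load("sample.wav")               # placeholder input clip
wav = torchaudio.functional.resample(wav, sr, 22050)  # the mel transform assumes 22050 Hz

mel_norms = torch.load("clips_mel_norms.pth")         # placeholder path to the (80,) norm tensor
mel = dvae_wav_to_mel(wav, mel_norms=mel_norms)       # -> (1, 80, T) normalized log-mel
```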
|
||||
|
||||
|
||||
class Quantize(nn.Module):
|
||||
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.n_embed = n_embed
|
||||
self.decay = decay
|
||||
self.eps = eps
|
||||
|
||||
self.balancing_heuristic = balancing_heuristic
|
||||
self.codes = None
|
||||
self.max_codes = 64000
|
||||
self.codes_full = False
|
||||
self.new_return_order = new_return_order
|
||||
|
||||
embed = torch.randn(dim, n_embed)
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("cluster_size", torch.zeros(n_embed))
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
def forward(self, input, return_soft_codes=False):
|
||||
if self.balancing_heuristic and self.codes_full:
|
||||
h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes)
|
||||
mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1)
|
||||
ep = self.embed.permute(1, 0)
|
||||
ea = self.embed_avg.permute(1, 0)
|
||||
rand_embed = torch.randn_like(ep) * mask
|
||||
self.embed = (ep * ~mask + rand_embed).permute(1, 0)
|
||||
self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0)
|
||||
self.cluster_size = self.cluster_size * ~mask.squeeze()
|
||||
if torch.any(mask):
|
||||
print(f"Reset {torch.sum(mask)} embedding codes.")
|
||||
self.codes = None
|
||||
self.codes_full = False
|
||||
|
||||
flatten = input.reshape(-1, self.dim)
|
||||
dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True)
|
||||
soft_codes = -dist
|
||||
_, embed_ind = soft_codes.max(1)
|
||||
embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
|
||||
embed_ind = embed_ind.view(*input.shape[:-1])
|
||||
quantize = self.embed_code(embed_ind)
|
||||
|
||||
if self.balancing_heuristic:
|
||||
if self.codes is None:
|
||||
self.codes = embed_ind.flatten()
|
||||
else:
|
||||
self.codes = torch.cat([self.codes, embed_ind.flatten()])
|
||||
if len(self.codes) > self.max_codes:
|
||||
self.codes = self.codes[-self.max_codes :]
|
||||
self.codes_full = True
|
||||
|
||||
if self.training:
|
||||
embed_onehot_sum = embed_onehot.sum(0)
|
||||
embed_sum = flatten.transpose(0, 1) @ embed_onehot
|
||||
|
||||
if distributed.is_initialized() and distributed.get_world_size() > 1:
|
||||
distributed.all_reduce(embed_onehot_sum)
|
||||
distributed.all_reduce(embed_sum)
|
||||
|
||||
self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
|
||||
self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
|
||||
n = self.cluster_size.sum()
|
||||
cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
diff = (quantize.detach() - input).pow(2).mean()
|
||||
quantize = input + (quantize - input).detach()
|
||||
|
||||
if return_soft_codes:
|
||||
return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,))
|
||||
elif self.new_return_order:
|
||||
return quantize, embed_ind, diff
|
||||
else:
|
||||
return quantize, diff, embed_ind
|
||||
|
||||
def embed_code(self, embed_id):
|
||||
return F.embedding(embed_id, self.embed.transpose(0, 1))
|
||||
|
||||
|
||||
# Fits a soft-discretized input to a normal-PDF across the specified dimension.
|
||||
# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete
|
||||
# values with the specified expected variance.
|
||||
class DiscretizationLoss(nn.Module):
|
||||
def __init__(self, discrete_bins, dim, expected_variance, store_past=0):
|
||||
super().__init__()
|
||||
self.discrete_bins = discrete_bins
|
||||
self.dim = dim
|
||||
self.dist = torch.distributions.Normal(0, scale=expected_variance)
|
||||
if store_past > 0:
|
||||
self.record_past = True
|
||||
self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu"))
|
||||
self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu"))
|
||||
self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins))
|
||||
else:
|
||||
self.record_past = False
|
||||
|
||||
def forward(self, x):
|
||||
other_dims = set(range(len(x.shape))) - set([self.dim])
|
||||
averaged = x.sum(dim=tuple(other_dims)) / x.sum()
|
||||
averaged = averaged - averaged.mean()
|
||||
|
||||
if self.record_past:
|
||||
acc_count = self.accumulator.shape[0]
|
||||
avg = averaged.detach().clone()
|
||||
if self.accumulator_filled > 0:
|
||||
averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count
|
||||
|
||||
# Also push averaged into the accumulator.
|
||||
self.accumulator[self.accumulator_index] = avg
|
||||
self.accumulator_index += 1
|
||||
if self.accumulator_index >= acc_count:
|
||||
self.accumulator_index *= 0
|
||||
if self.accumulator_filled <= 0:
|
||||
self.accumulator_filled += 1
|
||||
|
||||
return torch.sum(-self.dist.log_prob(averaged))
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, chan, conv, activation):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
conv(chan, chan, 3, padding=1),
|
||||
activation(),
|
||||
conv(chan, chan, 3, padding=1),
|
||||
activation(),
|
||||
conv(chan, chan, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
|
||||
class UpsampledConv(nn.Module):
|
||||
def __init__(self, conv, *args, **kwargs):
|
||||
super().__init__()
|
||||
assert "stride" in kwargs.keys()
|
||||
self.stride = kwargs["stride"]
|
||||
del kwargs["stride"]
|
||||
self.conv = conv(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest")
|
||||
return self.conv(up)
|
||||
|
||||
|
||||
# DiscreteVAE partially derived from lucidrains DALLE implementation
|
||||
# Credit: https://github.com/lucidrains/DALLE-pytorch
|
||||
class DiscreteVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
positional_dims=2,
|
||||
num_tokens=512,
|
||||
codebook_dim=512,
|
||||
num_layers=3,
|
||||
num_resnet_blocks=0,
|
||||
hidden_dim=64,
|
||||
channels=3,
|
||||
stride=2,
|
||||
kernel_size=4,
|
||||
use_transposed_convs=True,
|
||||
encoder_norm=False,
|
||||
activation="relu",
|
||||
smooth_l1_loss=False,
|
||||
straight_through=False,
|
||||
normalization=None, # ((0.5,) * 3, (0.5,) * 3),
|
||||
record_codes=False,
|
||||
discretization_loss_averaging_steps=100,
|
||||
lr_quantizer_args={},
|
||||
):
|
||||
super().__init__()
|
||||
has_resblocks = num_resnet_blocks > 0
|
||||
|
||||
self.num_tokens = num_tokens
|
||||
self.num_layers = num_layers
|
||||
self.straight_through = straight_through
|
||||
self.positional_dims = positional_dims
|
||||
self.discrete_loss = DiscretizationLoss(
|
||||
num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps
|
||||
)
|
||||
|
||||
assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now.
|
||||
if positional_dims == 2:
|
||||
conv = nn.Conv2d
|
||||
conv_transpose = nn.ConvTranspose2d
|
||||
else:
|
||||
conv = nn.Conv1d
|
||||
conv_transpose = nn.ConvTranspose1d
|
||||
if not use_transposed_convs:
|
||||
conv_transpose = functools.partial(UpsampledConv, conv)
|
||||
|
||||
if activation == "relu":
|
||||
act = nn.ReLU
|
||||
elif activation == "silu":
|
||||
act = nn.SiLU
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
enc_layers = []
|
||||
dec_layers = []
|
||||
|
||||
if num_layers > 0:
|
||||
enc_chans = [hidden_dim * 2**i for i in range(num_layers)]
|
||||
dec_chans = list(reversed(enc_chans))
|
||||
|
||||
enc_chans = [channels, *enc_chans]
|
||||
|
||||
dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
|
||||
dec_chans = [dec_init_chan, *dec_chans]
|
||||
|
||||
enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
|
||||
|
||||
pad = (kernel_size - 1) // 2
|
||||
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
||||
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act()))
|
||||
if encoder_norm:
|
||||
enc_layers.append(nn.GroupNorm(8, enc_out))
|
||||
dec_layers.append(
|
||||
nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act())
|
||||
)
|
||||
dec_out_chans = dec_chans[-1]
|
||||
innermost_dim = dec_chans[0]
|
||||
else:
|
||||
enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
|
||||
dec_out_chans = hidden_dim
|
||||
innermost_dim = hidden_dim
|
||||
|
||||
for _ in range(num_resnet_blocks):
|
||||
dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
|
||||
enc_layers.append(ResBlock(innermost_dim, conv, act))
|
||||
|
||||
if num_resnet_blocks > 0:
|
||||
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
||||
|
||||
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
||||
dec_layers.append(conv(dec_out_chans, channels, 1))
|
||||
|
||||
self.encoder = nn.Sequential(*enc_layers)
|
||||
self.decoder = nn.Sequential(*dec_layers)
|
||||
|
||||
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
|
||||
self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True)
|
||||
|
||||
# take care of normalization within class
|
||||
self.normalization = normalization
|
||||
self.record_codes = record_codes
|
||||
if record_codes:
|
||||
self.codes = torch.zeros((1228800,), dtype=torch.long)
|
||||
self.code_ind = 0
|
||||
self.total_codes = 0
|
||||
self.internal_step = 0
|
||||
|
||||
def norm(self, images):
|
||||
if self.normalization is None:
|
||||
return images
|
||||
|
||||
means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization)
|
||||
arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()"
|
||||
means, stds = map(lambda t: rearrange(t, arrange), (means, stds))
|
||||
images = images.clone()
|
||||
images.sub_(means).div_(stds)
|
||||
return images
|
||||
|
||||
def get_debug_values(self, step, __):
|
||||
if self.record_codes and self.total_codes > 0:
|
||||
# Report annealing schedule
|
||||
return {"histogram_codes": self.codes[: self.total_codes]}
|
||||
else:
|
||||
return {}
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
def get_codebook_indices(self, images):
|
||||
img = self.norm(images)
|
||||
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
||||
sampled, codes, _ = self.codebook(logits)
|
||||
self.log_codes(codes)
|
||||
return codes
|
||||
|
||||
def decode(self, img_seq):
|
||||
self.log_codes(img_seq)
|
||||
if hasattr(self.codebook, "embed_code"):
|
||||
image_embeds = self.codebook.embed_code(img_seq)
|
||||
else:
|
||||
image_embeds = F.embedding(img_seq, self.codebook.codebook)
|
||||
b, n, d = image_embeds.shape
|
||||
|
||||
kwargs = {}
|
||||
if self.positional_dims == 1:
|
||||
arrange = "b n d -> b d n"
|
||||
else:
|
||||
h = w = int(sqrt(n))
|
||||
arrange = "b (h w) d -> b d h w"
|
||||
kwargs = {"h": h, "w": w}
|
||||
image_embeds = rearrange(image_embeds, arrange, **kwargs)
|
||||
images = [image_embeds]
|
||||
for layer in self.decoder:
|
||||
images.append(layer(images[-1]))
|
||||
return images[-1], images[-2]
|
||||
|
||||
def infer(self, img):
|
||||
img = self.norm(img)
|
||||
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
||||
sampled, codes, commitment_loss = self.codebook(logits)
|
||||
return self.decode(codes)
|
||||
|
||||
# Note: This module is not meant to be run in forward() except while training. It has special logic which performs
|
||||
# evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially
|
||||
# more lossy (but useful for determining network performance).
|
||||
def forward(self, img):
|
||||
img = self.norm(img)
|
||||
logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1))
|
||||
sampled, codes, commitment_loss = self.codebook(logits)
|
||||
sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1))
|
||||
|
||||
if self.training:
|
||||
out = sampled
|
||||
for d in self.decoder:
|
||||
out = d(out)
|
||||
self.log_codes(codes)
|
||||
else:
|
||||
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||
out, _ = self.decode(codes)
|
||||
|
||||
# reconstruction loss
|
||||
recon_loss = self.loss_fn(img, out, reduction="none")
|
||||
|
||||
return recon_loss, commitment_loss, out
|
||||
|
||||
def log_codes(self, codes):
|
||||
# This is so we can debug the distribution of codes being learned.
|
||||
if self.record_codes and self.internal_step % 10 == 0:
|
||||
codes = codes.flatten()
|
||||
l = codes.shape[0]
|
||||
i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l
|
||||
self.codes[i : i + l] = codes.cpu()
|
||||
self.code_ind = self.code_ind + l
|
||||
if self.code_ind >= self.codes.shape[0]:
|
||||
self.code_ind = 0
|
||||
self.total_codes += 1
|
||||
self.internal_step += 1
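For orientation, a hedged sketch of pushing a mel batch through the `DiscreteVAE` defined above (hyperparameters are illustrative, not the values used to train the XTTS DVAE):

```python
import torch

dvae = DiscreteVAE(
    positional_dims=1,    # 1d: inputs are (batch, channels, frames)
    num_tokens=8192,
    codebook_dim=512,
    num_layers=2,
    num_resnet_blocks=3,
    hidden_dim=512,
    channels=80,          # 80 mel bins in and out
    record_codes=True,    # track code usage statistics via log_codes()
)

mel = torch.randn(1, 80, 256)            # stand-in for a normalized log-mel batch
codes = dvae.get_codebook_indices(mel)   # discrete codes, roughly (1, 256 / 2**num_layers)
recon, _ = dvae.decode(codes)            # approximate mel reconstruction
```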
|
|
@ -0,0 +1,545 @@
|
|||
# ported from: https://github.com/neonbjb/tortoise-tts
|
||||
|
||||
import functools
|
||||
import math
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config
|
||||
|
||||
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
|
||||
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
def __init__(self, seq_len, model_dim, init=0.02, relative=False):
|
||||
super().__init__()
|
||||
# nn.Embedding
|
||||
self.emb = torch.nn.Embedding(seq_len, model_dim)
|
||||
# Initializing this way is standard for GPT-2
|
||||
self.emb.weight.data.normal_(mean=0.0, std=init)
|
||||
self.relative = relative
|
||||
self.seq_len = seq_len
|
||||
|
||||
def forward(self, x):
|
||||
sl = x.shape[1]
|
||||
if self.relative:
|
||||
start = random.randint(sl, self.seq_len) - sl
|
||||
return self.emb(torch.arange(start, start + sl, device=x.device))
|
||||
else:
|
||||
return self.emb(torch.arange(0, sl, device=x.device))
|
||||
|
||||
def get_fixed_embedding(self, ind, dev):
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def build_hf_gpt_transformer(
|
||||
layers,
|
||||
model_dim,
|
||||
heads,
|
||||
max_mel_seq_len,
|
||||
max_text_seq_len,
|
||||
max_prompt_len,
|
||||
checkpointing,
|
||||
):
|
||||
"""
|
||||
GPT-2 implemented by the HuggingFace library.
|
||||
"""
|
||||
from transformers import GPT2Config, GPT2Model
|
||||
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=256, # Unused.
|
||||
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_embd=model_dim,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing,
|
||||
)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
# Override the built in positional embeddings
|
||||
del gpt.wpe
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
|
||||
mel_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
)
|
||||
text_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_dim)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_dim)
|
||||
)
|
||||
# gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
|
||||
return gpt, mel_pos_emb, text_pos_emb, None, None
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
start_text_token=261,
|
||||
stop_text_token=0,
|
||||
layers=8,
|
||||
model_dim=512,
|
||||
heads=8,
|
||||
max_text_tokens=120,
|
||||
max_mel_tokens=250,
|
||||
max_prompt_tokens=70,
|
||||
max_conditioning_inputs=1,
|
||||
code_stride_len=1024,
|
||||
number_text_tokens=256,
|
||||
num_audio_tokens=8194,
|
||||
start_audio_token=8192,
|
||||
stop_audio_token=8193,
|
||||
train_solo_embeddings=False,
|
||||
checkpointing=False,
|
||||
average_conditioning_embeddings=False,
|
||||
label_smoothing=0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.label_smoothing = label_smoothing
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = start_text_token
|
||||
self.stop_text_token = stop_text_token
|
||||
self.num_audio_tokens = num_audio_tokens
|
||||
self.start_audio_token = start_audio_token
|
||||
self.stop_audio_token = stop_audio_token
|
||||
self.start_prompt_token = start_audio_token
|
||||
self.stop_prompt_token = stop_audio_token
|
||||
self.layers = layers
|
||||
self.heads = heads
|
||||
self.model_dim = model_dim
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||
self.max_prompt_tokens = max_prompt_tokens
|
||||
self.code_stride_len = code_stride_len
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.conditioning_dropout = nn.Dropout1d(0.1)
|
||||
self.average_conditioning_embeddings = average_conditioning_embeddings
|
||||
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
|
||||
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||
|
||||
self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
|
||||
|
||||
(
|
||||
self.gpt,
|
||||
self.mel_pos_embedding,
|
||||
self.text_pos_embedding,
|
||||
self.mel_layer_pos_embedding,
|
||||
self.text_layer_pos_embedding,
|
||||
) = build_hf_gpt_transformer(
|
||||
layers,
|
||||
model_dim,
|
||||
heads,
|
||||
self.max_mel_tokens,
|
||||
self.max_text_tokens,
|
||||
self.max_prompt_tokens,
|
||||
checkpointing,
|
||||
)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
|
||||
else:
|
||||
self.mel_solo_embedding = 0
|
||||
self.text_solo_embedding = 0
|
||||
|
||||
self.final_norm = nn.LayerNorm(model_dim)
|
||||
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||
"gpt": list(self.gpt.parameters()),
|
||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_gpt_for_inference(self, kv_cache=True):
|
||||
seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_mel_tokens,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=self.model_dim,
|
||||
n_layer=self.layers,
|
||||
n_head=self.heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
)
|
||||
self.gpt_inference = GPT2InferenceModel(
|
||||
gpt_config,
|
||||
self.gpt,
|
||||
self.mel_pos_embedding,
|
||||
self.mel_embedding,
|
||||
self.final_norm,
|
||||
self.mel_head,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
self.gpt.wte = self.mel_embedding
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_mel_padding(self, mel_input_tokens, code_lengths):
|
||||
"""
|
||||
Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
|
||||
that audio clip, reformats the tokens with stop_audio_token in place of the zero padding. This is required
|
||||
preformatting to create a working TTS model.
|
||||
"""
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
for b in range(len(code_lengths)):
|
||||
actual_end = code_lengths[b]
|
||||
if actual_end < mel_input_tokens.shape[-1]:
|
||||
mel_input_tokens[b, actual_end:] = self.stop_audio_token
|
||||
return mel_input_tokens
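A tiny worked example of the rewrite this method performs (token values are made up; `8193` stands in for `stop_audio_token`):

```python
import torch

stop_audio_token = 8193
mel_tokens = torch.tensor([[511, 90, 7, 0, 0, 0]])  # tail still holds the DVAE padding code
code_lengths = torch.tensor([3])                    # true number of codes for this clip

mel_tokens[0, code_lengths[0]:] = stop_audio_token
print(mel_tokens)  # tensor([[ 511,   90,    7, 8193, 8193, 8193]])
```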
|
||||
|
||||
def get_logits(
|
||||
self,
|
||||
first_inputs,
|
||||
first_head,
|
||||
second_inputs=None,
|
||||
second_head=None,
|
||||
prompt=None,
|
||||
get_attns=False,
|
||||
return_latent=False,
|
||||
attn_mask_text=None,
|
||||
attn_mask_mel=None,
|
||||
):
|
||||
if prompt is not None:
|
||||
offset = prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([prompt, first_inputs], dim=1)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
attn_mask = None
|
||||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
return_dict=True,
|
||||
output_attentions=get_attns,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, offset:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
|
||||
|
||||
first_logits = enc[:, : first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0, 2, 1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1] :]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0, 2, 1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
|
||||
def get_conditioning(self, speech_conditioning_input):
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
||||
def get_prompts(self, prompt_codes):
|
||||
"""
|
||||
Create a prompt from the mel codes. This is used to condition the model on the mel codes.
|
||||
Pad the prompt with start and stop mel tokens.
|
||||
"""
|
||||
prompt = prompt_codes
|
||||
if self.training:
|
||||
lengths = []
|
||||
# Compute the real prompt length based on the first encounter with the token 83 used for padding
|
||||
for i in range(prompt_codes.shape[0]):
|
||||
length = 0
|
||||
for j in range(prompt_codes.shape[1]):
|
||||
if prompt_codes[i, j] == 83:
|
||||
break
|
||||
else:
|
||||
length += 1
|
||||
lengths.append(length)
|
||||
|
||||
# prompt_len = random.randint(1, 9) # in secs
|
||||
prompt_len = 3
|
||||
prompt_len = prompt_len * 24 # in frames
|
||||
if prompt_codes.shape[-1] >= prompt_len:
|
||||
new_prompt = []
|
||||
for i in range(prompt_codes.shape[0]):
|
||||
if lengths[i] < prompt_len:
|
||||
start = 0
|
||||
else:
|
||||
start = random.randint(0, lengths[i] - prompt_len)
|
||||
prompt = prompt_codes[:, start : start + prompt_len]
|
||||
|
||||
# add start and stop tokens
|
||||
prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
|
||||
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
|
||||
return prompt
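For illustration, the start/stop padding applied at the end of `get_prompts` works like this (the token ids match the constructor defaults):

```python
import torch
import torch.nn.functional as F

start_prompt_token, stop_prompt_token = 8192, 8193
prompt_codes = torch.tensor([[12, 7, 99]])

prompt = F.pad(prompt_codes, (1, 0), value=start_prompt_token)
prompt = F.pad(prompt, (0, 1), value=stop_prompt_token)
print(prompt)  # tensor([[8192,   12,    7,   99, 8193]])
```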
|
||||
|
||||
def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
|
||||
"""
|
||||
cond_input: (b, 80, s) or (b, 1, 80, s)
|
||||
conds: (b, 1024, s)
|
||||
"""
|
||||
conds = None
|
||||
if not return_latent:
|
||||
if cond_input.ndim == 4:
|
||||
cond_input = cond_input.squeeze(1)
|
||||
if sample:
|
||||
_len_secs = random.randint(2, 6) # in secs
|
||||
cond_seg_len = int((22050 / 1024) * _len_secs) # in frames
|
||||
if cond_input.shape[-1] >= cond_seg_len:
|
||||
new_conds = []
|
||||
for i in range(cond_input.shape[0]):
|
||||
cond_len = int(cond_lens[i] / 1024)
|
||||
if cond_len < cond_seg_len:
|
||||
start = 0
|
||||
else:
|
||||
start = random.randint(0, cond_len - cond_seg_len)
|
||||
cond_vec = cond_input[i, :, start : start + cond_seg_len]
|
||||
new_conds.append(cond_vec)
|
||||
conds = torch.stack(new_conds, dim=0)
|
||||
else:
|
||||
cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs
|
||||
cond_frame_len = int((22050 / 1024) * cond_seg_len)
|
||||
conds = cond_input[:, :, -cond_frame_len:]
|
||||
|
||||
conds = self.conditioning_encoder(conds)
|
||||
else:
|
||||
# already computed
|
||||
conds = cond_input.unsqueeze(1)
|
||||
return conds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_inputs,
|
||||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
cond_lens=None,
|
||||
cond_mels=None,
|
||||
cond_latents=None,
|
||||
loss_weights=None,
|
||||
return_attentions=False,
|
||||
return_latent=False,
|
||||
):
|
||||
"""
|
||||
Forward pass conditioning the model on speaker reference mels (or precomputed latents), text tokens, and audio codes.

cond_mels: MEL float tensor, (b, 1, 80, s)
text_inputs: long tensor, (b, t)
text_lengths: long tensor, (b,)
audio_codes: long tensor, (b, m)
wav_lengths: long tensor, (b,)

If return_attentions is specified, only logits are returned.
If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
"""
|
||||
# ❗ FIXIT
|
||||
if self.max_conditioning_inputs == 0:
|
||||
assert cond_mels is None, " ❗ cond_mels is not None, but max_conditioning_inputs == 0"
|
||||
|
||||
max_text_len = text_lengths.max()
|
||||
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
|
||||
|
||||
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
|
||||
max_mel_len = code_lengths.max()
|
||||
|
||||
if max_mel_len > audio_codes.shape[-1]:
|
||||
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
||||
|
||||
silence = True
|
||||
for idx, l in enumerate(code_lengths):
|
||||
length = l.item()
|
||||
while silence:
|
||||
if audio_codes[idx, length - 1] != 83:
|
||||
break
|
||||
length -= 1
|
||||
code_lengths[idx] = length
|
||||
|
||||
# 💖 Lovely assertions
|
||||
assert (
|
||||
max_mel_len <= audio_codes.shape[-1]
|
||||
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
|
||||
assert (
|
||||
max_text_len <= text_inputs.shape[-1]
|
||||
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
|
||||
|
||||
# Append stop token to text inputs
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
|
||||
# Append silence token to mel codes
|
||||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)
|
||||
|
||||
# Pad mel codes with stop_audio_token
|
||||
audio_codes = self.set_mel_padding(audio_codes, code_lengths)
|
||||
|
||||
# Build input and target tensors
|
||||
# Prepend start token to inputs and append stop token to targets
|
||||
text_inputs, text_targets = self.set_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token
|
||||
)
|
||||
audio_codes, mel_targets = self.set_inputs_and_targets(
|
||||
audio_codes, self.start_audio_token, self.stop_audio_token
|
||||
)
|
||||
|
||||
# Set attn_mask
|
||||
attn_mask_text = None
|
||||
attn_mask_mel = None
|
||||
if not return_latent:
|
||||
attn_mask_text = torch.ones(
|
||||
text_inputs.shape[0],
|
||||
text_inputs.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
attn_mask_mel = torch.ones(
|
||||
audio_codes.shape[0],
|
||||
audio_codes.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=audio_codes.device,
|
||||
)
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
||||
for idx, l in enumerate(code_lengths):
|
||||
attn_mask_mel[idx, l + 1 :] = 0.0
|
||||
|
||||
# Compute text embeddings + positional embeddings
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
# Compute mel embeddings + positional embeddings
|
||||
mel_emb = self.mel_embedding(audio_codes) + self.mel_pos_embedding(audio_codes)
|
||||
|
||||
# Compute speech conditioning input
|
||||
if cond_latents is None:
|
||||
cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
|
||||
|
||||
# Get logits
|
||||
sub = -5 # don't ask me why 😄
|
||||
if self.training:
|
||||
sub = -1
|
||||
|
||||
text_logits, mel_logits = self.get_logits(
|
||||
text_emb,
|
||||
self.text_head,
|
||||
mel_emb,
|
||||
self.mel_head,
|
||||
prompt=cond_latents,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
if return_latent:
|
||||
return mel_logits[:, :sub] # sub to prevent bla.
|
||||
|
||||
if return_attentions:
|
||||
return mel_logits
|
||||
|
||||
# Set paddings to -1 to ignore them in loss
|
||||
for idx, l in enumerate(text_lengths):
|
||||
text_targets[idx, l + 1 :] = -1
|
||||
|
||||
for idx, l in enumerate(code_lengths):
|
||||
mel_targets[idx, l + 1 :] = -1
|
||||
|
||||
# check if stoptoken is in every row of mel_targets
|
||||
assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
|
||||
0
|
||||
], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."
|
||||
|
||||
# Compute losses
|
||||
loss_text = F.cross_entropy(
|
||||
text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
|
||||
)
|
||||
loss_mel = F.cross_entropy(
|
||||
mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
|
||||
)
|
||||
return loss_text.mean(), loss_mel.mean(), mel_logits
|
||||
|
||||
def inference(self, cond_latents, text_inputs, **hf_generate_kwargs):
|
||||
self.compute_embeddings(cond_latents, text_inputs)
|
||||
return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
cond_latents,
|
||||
text_inputs,
|
||||
):
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
emb = torch.cat([cond_latents, emb], dim=1)
|
||||
self.gpt_inference.store_prefix_emb(emb)
|
||||
gpt_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_audio_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
gpt_inputs[:, -1] = self.start_audio_token
|
||||
return gpt_inputs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
cond_latents,
|
||||
text_inputs,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
gpt_inputs = self.compute_embeddings(cond_latents, text_inputs)
|
||||
gen = self.gpt_inference.generate(
|
||||
gpt_inputs,
|
||||
bos_token_id=self.start_audio_token,
|
||||
pad_token_id=self.stop_audio_token,
|
||||
eos_token_id=self.stop_audio_token,
|
||||
max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, gpt_inputs.shape[1] :], gen
|
||||
return gen[:, gpt_inputs.shape[1] :]
|
|
@ -0,0 +1,658 @@
|
|||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
||||
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
|
||||
|
||||
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
|
||||
super().__init__(config)
|
||||
self.transformer = gpt
|
||||
self.pos_embedding = pos_emb
|
||||
self.embeddings = embeddings
|
||||
self.final_norm = norm
|
||||
self.lm_head = nn.Sequential(norm, linear)
|
||||
self.kv_cache = kv_cache
|
||||
|
||||
def store_prefix_emb(self, prefix_emb):
|
||||
self.cached_prefix_emb = prefix_emb
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None) # usually None
|
||||
if not self.kv_cache:
|
||||
past_key_values = None
|
||||
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values is not None:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
assert self.cached_prefix_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
|
||||
|
||||
# Create embedding
|
||||
prefix_len = self.cached_prefix_emb.shape[1]
|
||||
if input_ids.shape[1] != 1:
|
||||
gen_inputs = input_ids[:, prefix_len:]
|
||||
gen_emb = self.embeddings(gen_inputs)
|
||||
gen_emb = gen_emb + self.pos_embedding(gen_emb)
|
||||
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
|
||||
prefix_emb = self.cached_prefix_emb.repeat_interleave(
|
||||
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
|
||||
)
|
||||
else:
|
||||
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
|
||||
emb = torch.cat([prefix_emb, gen_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
class LearnedPositionEmbeddings(nn.Module):
|
||||
def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(seq_len, model_channels)
|
||||
nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
|
||||
self.relative = relative
|
||||
|
||||
def forward(self, x):
|
||||
seq_len = x.shape[1]
|
||||
if self.relative:
|
||||
start = torch.randint(seq_len, (1,), device=x.device).item()
|
||||
positions = torch.arange(start, start + seq_len, device=x.device)
|
||||
else:
|
||||
positions = torch.arange(seq_len, device=x.device)
|
||||
return self.emb(positions)
|
||||
|
||||
def get_fixed_embedding(self, ind, dev):
|
||||
return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
|
||||
|
||||
|
||||
def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
|
||||
"""
|
||||
Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
|
||||
|
||||
Args:
|
||||
layers (int): Number of layers in the GPT-2 model.
|
||||
model_channels (int): Dimension of the GPT-2 model.
|
||||
heads (int): Number of heads in the GPT-2 model.
|
||||
max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
|
||||
max_text_seq_len (int): Maximum sequence length for the text.
|
||||
max_prompt_len (int): Maximum length of the prompt.
|
||||
checkpointing (bool): Whether to use gradient checkpointing.
|
||||
|
||||
Returns:
|
||||
gpt (GPT2Model): GPT-2 model.
|
||||
mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
|
||||
text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
|
||||
"""
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=123,
|
||||
n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
|
||||
n_embd=model_channels,
|
||||
n_layer=layers,
|
||||
n_head=heads,
|
||||
gradient_checkpointing=checkpointing,
|
||||
use_cache=not checkpointing,
|
||||
)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
|
||||
del gpt.wpe
|
||||
del gpt.wte
|
||||
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
|
||||
|
||||
audio_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
text_pos_emb = (
|
||||
LearnedPositionEmbeddings(max_text_seq_len, model_channels)
|
||||
if max_mel_seq_len != -1
|
||||
else functools.partial(null_position_embeddings, dim=model_channels)
|
||||
)
|
||||
|
||||
return gpt, audio_pos_emb, text_pos_emb
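A hedged sketch of calling `init_gpt`; the sizes are illustrative and already include the extra start/stop slots that the encoder below reserves:

```python
gpt, audio_pos_emb, text_pos_emb = init_gpt(
    layers=8,
    model_channels=512,
    heads=8,
    max_mel_seq_len=253,   # e.g. 250 audio tokens + start/stop + 1 conditioning slot
    max_text_seq_len=122,  # e.g. 120 text tokens + start/stop
    max_prompt_len=70,
    checkpointing=False,
)
print(sum(p.numel() for p in gpt.parameters()))  # rough size of the GPT-2 backbone
```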
|
||||
|
||||
|
||||
class XTTSGPTEncoder(nn.Module):
|
||||
"""XTTS GPT Encoder model implementation.
|
||||
Args:
|
||||
start_text_token (int): Index of the start token in the text vocabulary.
|
||||
stop_text_token (int): Index of the stop token in the text vocabulary.
|
||||
n_layers (int): Number of layers in the GPT-2 model.
|
||||
n_model_channels (int): Dimension of the GPT-2 model.
|
||||
n_heads (int): Number of heads in the GPT-2 model.
|
||||
max_text_tokens (int): Maximum number of text tokens.
|
||||
max_audio_tokens (int): Maximum number of audio tokens.
|
||||
max_prompt_tokens (int): Maximum number of prompt tokens.
|
||||
audio_len_compression (int): Compression factor for the audio length.
|
||||
number_text_tokens (int): Number of text tokens.
|
||||
number_audio_codes (int): Number of audio codes.
|
||||
start_mel_token (int): Index of the start token in the mel code vocabulary.
|
||||
stop_mel_token (int): Index of the stop token in the mel code vocabulary.
|
||||
checkpointing (bool): Whether or not to use gradient checkpointing at training.
|
||||
"""
|
||||
|
||||
_inference_flag = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_text_token=261,
|
||||
stop_text_token=0,
|
||||
n_layers=8,
|
||||
n_model_channels=512,
|
||||
n_heads=8,
|
||||
max_text_tokens=120,
|
||||
max_audio_tokens=250,
|
||||
max_prompt_tokens=70,
|
||||
audio_len_compression=1024,
|
||||
number_text_tokens=256,
|
||||
number_audio_codes=8194,
|
||||
start_mel_token=8192,
|
||||
stop_mel_token=8193,
|
||||
checkpointing=True,
|
||||
label_smoothing=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.label_smoothing = label_smoothing
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = start_text_token
|
||||
self.stop_text_token = stop_text_token
|
||||
self.number_audio_codes = number_audio_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
self.stop_mel_token = stop_mel_token
|
||||
self.start_prompt_token = start_mel_token
|
||||
self.stop_prompt_token = stop_mel_token
|
||||
self.n_layers = n_layers
|
||||
self.n_heads = n_heads
|
||||
self.n_model_channels = n_model_channels
|
||||
self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
|
||||
self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
|
||||
self.max_prompt_tokens = max_prompt_tokens
|
||||
self.audio_len_compression = audio_len_compression
|
||||
|
||||
# embedding layers
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
|
||||
self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
|
||||
self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
|
||||
|
||||
# initialize the GPT-2 model
|
||||
(
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.text_pos_embedding,
|
||||
) = init_gpt(
|
||||
n_layers,
|
||||
n_model_channels,
|
||||
n_heads,
|
||||
self.max_audio_tokens,
|
||||
self.max_text_tokens,
|
||||
self.max_prompt_tokens,
|
||||
checkpointing,
|
||||
)
|
||||
|
||||
# output layers
|
||||
self.final_norm = nn.LayerNorm(n_model_channels)
|
||||
self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
|
||||
self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
|
||||
return {
|
||||
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
|
||||
"gpt": list(self.gpt.parameters()),
|
||||
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
|
||||
}
|
||||
|
||||
def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
|
||||
self._inference_flag = True
|
||||
seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
|
||||
gpt_config = GPT2Config(
|
||||
vocab_size=self.max_audio_tokens,
|
||||
n_positions=seq_length,
|
||||
n_ctx=seq_length,
|
||||
n_embd=self.n_model_channels,
|
||||
n_layer=self.n_layers,
|
||||
n_head=self.n_heads,
|
||||
gradient_checkpointing=False,
|
||||
use_cache=True,
|
||||
)
|
||||
self.inference_model = GPT2InferenceModel(
|
||||
gpt_config,
|
||||
self.gpt,
|
||||
self.audio_pos_embedding,
|
||||
self.audio_embedding,
|
||||
self.final_norm,
|
||||
self.mel_head,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
self.gpt.wte = self.audio_embedding
|
||||
|
||||
def set_inputs_and_targets(self, input, start_token, stop_token):
|
||||
inp = F.pad(input, (1, 0), value=start_token)
|
||||
tar = F.pad(input, (0, 1), value=stop_token)
|
||||
return inp, tar
|
||||
|
||||
def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
for b in range(len(audio_token_lens)):
|
||||
actual_end = audio_token_lens[b]
|
||||
if actual_end < audio_tokens.shape[-1]:
|
||||
audio_tokens[b, actual_end:] = self.stop_mel_token
|
||||
return audio_tokens
|
||||
|
||||
def get_logits(
|
||||
self,
|
||||
speech_conditioning_inputs,
|
||||
first_inputs,
|
||||
first_head,
|
||||
second_inputs=None,
|
||||
second_head=None,
|
||||
prompt=None,
|
||||
get_attns=False,
|
||||
return_latent=False,
|
||||
attn_mask_text=None,
|
||||
attn_mask_mel=None,
|
||||
):
|
||||
if prompt is not None and speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat(
|
||||
[speech_conditioning_inputs, prompt, first_inputs, second_inputs],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
|
||||
elif speech_conditioning_inputs is not None:
|
||||
offset = speech_conditioning_inputs.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
elif prompt is not None:
|
||||
offset = prompt.shape[1]
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([prompt, first_inputs], dim=1)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
attn_mask = None
|
||||
if attn_mask_text is not None:
|
||||
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
|
||||
if prompt is not None:
|
||||
attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
|
||||
attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb,
|
||||
return_dict=True,
|
||||
output_attentions=get_attns,
|
||||
attention_mask=attn_mask,
|
||||
)
|
||||
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[:, offset:]
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
|
||||
|
||||
first_logits = enc[:, : first_inputs.shape[1]]
|
||||
first_logits = first_head(first_logits)
|
||||
first_logits = first_logits.permute(0, 2, 1)
|
||||
if second_inputs is not None:
|
||||
second_logits = enc[:, -second_inputs.shape[1] :]
|
||||
second_logits = second_head(second_logits)
|
||||
second_logits = second_logits.permute(0, 2, 1)
|
||||
return first_logits, second_logits
|
||||
else:
|
||||
return first_logits
|
||||
|
||||
def get_conditioning(self, speech_conditioning_input):
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
conds = conds.mean(dim=1)
|
||||
return conds
|
||||
|
||||
def get_prompts(self, prompt_codes):
|
||||
prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)  # pad the already start-padded prompt, not the raw codes
|
||||
return prompt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_inputs,
|
||||
text_lengths,
|
||||
audio_codes,
|
||||
wav_lengths,
|
||||
prompt_codes,
speech_conditioning_input=None,  # added so the reference below resolves; not passed by callers in this commit
return_attentions=False,
return_latent=False,
|
||||
):
|
||||
max_text_len = text_lengths.max()
|
||||
|
||||
# Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
|
||||
# Like [..., 186, 45, 45, 83] where actually it should end with 186.
|
||||
# We take last 3 codes to prevent abrupt ending of the audio.
|
||||
# TODO: This might need some testing.
|
||||
mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
|
||||
|
||||
# If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
|
||||
max_mel_len = mel_lengths.max()
|
||||
|
||||
if max_mel_len > audio_codes.shape[-1]:
|
||||
audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
|
||||
|
||||
# silence aware lengths, skip the silence tokens at the end of the mel codes.
|
||||
silence = True
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
length = l.item()
|
||||
while silence:
|
||||
if audio_codes[idx, length - 1] != 83:
|
||||
break
|
||||
length -= 1
|
||||
mel_lengths[idx] = length
|
||||
|
||||
# Lovely assertions
|
||||
assert (
|
||||
max_mel_len <= audio_codes.shape[-1]
|
||||
), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
|
||||
assert (
|
||||
max_text_len <= text_inputs.shape[-1]
|
||||
), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
|
||||
|
||||
# Append stop token to text inputs
|
||||
text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
|
||||
|
||||
# Append silence token to mel codes
|
||||
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
|
||||
|
||||
# Pad mel codes with STOP_MEL_TOKEN
|
||||
audio_codes = self.set_audio_tokens_padding(audio_codes, mel_lengths)
|
||||
|
||||
# Compute speech conditioning input
|
||||
conds = None
|
||||
if speech_conditioning_input is not None:
|
||||
if not return_latent:
|
||||
# Compute speech conditioning input
|
||||
speech_conditioning_input = (
|
||||
speech_conditioning_input.unsqueeze(1)
|
||||
if len(speech_conditioning_input.shape) == 3
|
||||
else speech_conditioning_input
|
||||
)
|
||||
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
|
||||
conds = torch.stack(conds, dim=1)
|
||||
if self.average_conditioning_embeddings:
|
||||
conds = conds.mean(dim=1).unsqueeze(1)
|
||||
else:
|
||||
# already computed
|
||||
conds = speech_conditioning_input.unsqueeze(1)
|
||||
|
||||
# Build input and target tensors
|
||||
# Prepend start token to inputs and append stop token to targets
|
||||
text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
|
||||
audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
|
||||
|
||||
# Set attn_mask
|
||||
attn_mask_text = None
|
||||
attn_mask_mel = None
|
||||
if not return_latent:
|
||||
attn_mask_text = torch.ones(
|
||||
text_inputs.shape[0],
|
||||
text_inputs.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
attn_mask_mel = torch.ones(
|
||||
audio_codes.shape[0],
|
||||
audio_codes.shape[1],
|
||||
dtype=torch.bool,
|
||||
device=audio_codes.device,
|
||||
)
|
||||
|
||||
for idx, l in enumerate(text_lengths):
|
||||
attn_mask_text[idx, l + 1 :] = 0.0
|
||||
|
||||
for idx, l in enumerate(mel_lengths):
|
||||
attn_mask_mel[idx, l + 1 :] = 0.0
|
||||
|
||||
# Compute text embeddings + positional embeddings
|
||||
# print(" > text input latent:", text_inputs)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
# Compute mel embeddings + positional embeddings
|
||||
audio_emb = self.audio_embedding(audio_codes) + self.audio_pos_embedding(audio_codes)
|
||||
|
||||
# Compute prompt embeddings + positional embeddings
|
||||
prompt = self.get_prompts(prompt_codes)
|
||||
|
||||
# prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
|
||||
prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
|
||||
|
||||
# dropout prompt embeddings
|
||||
prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
|
||||
|
||||
# Get logits
|
||||
sub = -4  # at inference, drop the last 4 logits (empirically avoids artifacts at the end)
|
||||
if self.training:
|
||||
sub = -1
|
||||
_, audio_logits = self.get_logits(
|
||||
conds,
|
||||
text_emb,
|
||||
self.text_head,
|
||||
audio_emb,
|
||||
self.mel_head,
|
||||
prompt=prompt_emb,
|
||||
get_attns=return_attentions,
|
||||
return_latent=return_latent,
|
||||
attn_mask_text=attn_mask_text,
|
||||
attn_mask_mel=attn_mask_mel,
|
||||
)
|
||||
return audio_logits[:, :sub]  # drop the trailing positions (see `sub` above)
|
||||
|
||||
def compute_embeddings(
|
||||
self,
|
||||
speech_conditioning_latent,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
):
|
||||
"""Compute all the embeddings needed for inference."""
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
print(" > Text inputs:", text_inputs)
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
# prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
|
||||
print(" > Prompt inputs:", prompt_codes)
|
||||
print(" > Prompt inputs shape:", prompt_codes.shape)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
if speech_conditioning_latent is not None:
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
emb = torch.cat([conds, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
return fake_inputs
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text_inputs,
|
||||
input_tokens=None,
|
||||
prompt_codes=None,
|
||||
pad_input_text=False,
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
if pad_input_text and text_inputs.shape[1] < 250:
|
||||
text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
|
||||
else:
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
|
||||
|
||||
emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
if prompt_codes is not None:
|
||||
prompt_codes = self.get_prompts(prompt_codes)
|
||||
prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
|
||||
emb = torch.cat([prompt_emb, emb], dim=1)
|
||||
|
||||
self.inference_model.store_prefix_emb(emb)
|
||||
|
||||
fake_inputs = torch.full(
|
||||
(
|
||||
emb.shape[0],
|
||||
emb.shape[1] + 1, # +1 for the start_mel_token
|
||||
),
|
||||
fill_value=1,
|
||||
dtype=torch.long,
|
||||
device=text_inputs.device,
|
||||
)
|
||||
fake_inputs[:, -1] = self.start_mel_token
|
||||
|
||||
if input_tokens is not None:
|
||||
fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
gen = self.inference_model.generate(
|
||||
fake_inputs,
|
||||
bos_token_id=self.start_mel_token,
|
||||
pad_token_id=self.stop_mel_token,
|
||||
eos_token_id=self.stop_mel_token,
|
||||
max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
if "return_dict_in_generate" in hf_generate_kwargs:
|
||||
return gen.sequences[:, fake_inputs.shape[1] :], gen
|
||||
return gen[:, fake_inputs.shape[1] :]
|
|
@ -0,0 +1,136 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import GPT2PreTrainedModel
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
|
||||
class GPT2InferenceModel(GPT2PreTrainedModel):
|
||||
"""Override GPT2LMHeadModel to allow for prefix conditioning."""
|
||||
|
||||
def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
|
||||
super().__init__(config)
|
||||
self.transformer = gpt
|
||||
self.pos_embedding = pos_emb
|
||||
self.embeddings = embeddings
|
||||
self.final_norm = norm
|
||||
self.lm_head = nn.Sequential(norm, linear)
|
||||
self.kv_cache = kv_cache
|
||||
|
||||
def store_prefix_emb(self, prefix_emb):
|
||||
self.cached_prefix_emb = prefix_emb
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
|
||||
token_type_ids = kwargs.get("token_type_ids", None) # usually None
|
||||
if not self.kv_cache:
|
||||
past_key_values = None
|
||||
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past_key_values is not None:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values is not None:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
assert self.cached_prefix_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
|
||||
|
||||
# Create embedding
|
||||
prefix_len = self.cached_prefix_emb.shape[1]
|
||||
if input_ids.shape[1] != 1:
|
||||
gen_inputs = input_ids[:, prefix_len:]
|
||||
gen_emb = self.embeddings(gen_inputs)
|
||||
gen_emb = gen_emb + self.pos_embedding(gen_emb)
|
||||
if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
|
||||
prefix_emb = self.cached_prefix_emb.repeat_interleave(
|
||||
gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
|
||||
)
|
||||
else:
|
||||
prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
|
||||
emb = torch.cat([prefix_emb, gen_emb], dim=1)
|
||||
else:
|
||||
emb = self.embeddings(input_ids)
|
||||
emb = emb + self.pos_embedding.get_fixed_embedding(
|
||||
attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
|
||||
)
|
||||
transformer_outputs = self.transformer(
|
||||
inputs_embeds=emb,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=None,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
|
@ -0,0 +1,141 @@
|
|||
# Originally ported from: https://github.com/neonbjb/tortoise-tts
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
groups = 32
|
||||
if channels <= 16:
|
||||
groups = 8
|
||||
elif channels <= 64:
|
||||
groups = 16
|
||||
while channels % groups != 0:
|
||||
groups = int(groups / 2)
|
||||
assert groups > 2
|
||||
return GroupNorm32(groups, channels)
|
||||
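# Worked example (not part of the original commit): normalization(80) starts from 32 groups,
# halves once because 80 % 32 != 0, and returns GroupNorm32(16, 80); normalization(12) starts
# from 8 groups (channels <= 16) and halves once to GroupNorm32(4, 12).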
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv, mask=None, qk_bias=0):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = weight + qk_bias
|
||||
if mask is not None:
|
||||
mask = mask.repeat(self.n_heads, 1, 1)
|
||||
weight[mask.logical_not()] = -torch.inf
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = torch.einsum("bts,bcs->bct", weight, v)
|
||||
|
||||
return a.reshape(bs, -1, length)
|
||||
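# Shape sketch (assumed example, not part of the original commit): with n_heads = 4 and
# qkv of shape [N, 4 * 3 * C, T], q, k and v each become [N * 4, C, T], the attention
# weights are [N * 4, T, T], and the returned tensor is reshaped back to [N, 4 * C, T].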
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""An attention block that allows spatial positions to attend to each other."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
out_channels=None,
|
||||
do_activation=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
out_channels = channels if out_channels is None else out_channels
|
||||
self.do_activation = do_activation
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, out_channels * 3, 1)
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1)
|
||||
self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1))
|
||||
|
||||
def forward(self, x, mask=None, qk_bias=0):
|
||||
b, c, *spatial = x.shape
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
||||
if mask.shape[1] != x.shape[-1]:
|
||||
mask = mask[:, : x.shape[-1], : x.shape[-1]]
|
||||
|
||||
x = x.reshape(b, c, -1)
|
||||
x = self.norm(x)
|
||||
if self.do_activation:
|
||||
x = F.silu(x, inplace=True)
|
||||
qkv = self.qkv(x)
|
||||
h = self.attention(qkv, mask=mask, qk_bias=qk_bias)
|
||||
h = self.proj_out(h)
|
||||
xp = self.x_proj(x)
|
||||
return (xp + h).reshape(b, xp.shape[1], *spatial)
|
||||
|
||||
|
||||
class ConditioningEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
spec_dim,
|
||||
embedding_dim,
|
||||
attn_blocks=6,
|
||||
num_attn_heads=4,
|
||||
):
|
||||
super().__init__()
|
||||
attn = []
|
||||
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
|
||||
for a in range(attn_blocks):
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (b, 80, s)
|
||||
"""
|
||||
h = self.init(x)
|
||||
h = self.attn(h)
|
||||
return h
|
|
@ -0,0 +1,286 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import inflect
|
||||
import pandas as pd
|
||||
import pypinyin
|
||||
import torch
|
||||
from num2words import num2words
|
||||
from tokenizers import Tokenizer
|
||||
from unidecode import unidecode
|
||||
|
||||
from TTS.tts.utils.text.cleaners import english_cleaners
|
||||
|
||||
_inflect = inflect.engine()
|
||||
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
||||
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
||||
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
||||
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
||||
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
||||
_number_re = re.compile(r"[0-9]+")
|
||||
|
||||
|
||||
def _remove_commas(m):
|
||||
return m.group(1).replace(",", "")
|
||||
|
||||
|
||||
def _expand_decimal_point(m):
|
||||
return m.group(1).replace(".", " point ")
|
||||
|
||||
|
||||
def _expand_dollars(m):
|
||||
match = m.group(1)
|
||||
parts = match.split(".")
|
||||
if len(parts) > 2:
|
||||
return match + " dollars" # Unexpected format
|
||||
dollars = int(parts[0]) if parts[0] else 0
|
||||
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
||||
if dollars and cents:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
|
||||
elif dollars:
|
||||
dollar_unit = "dollar" if dollars == 1 else "dollars"
|
||||
return "%s %s" % (dollars, dollar_unit)
|
||||
elif cents:
|
||||
cent_unit = "cent" if cents == 1 else "cents"
|
||||
return "%s %s" % (cents, cent_unit)
|
||||
else:
|
||||
return "zero dollars"
|
||||
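# Illustrative examples (not part of the original commit): a _dollars_re match on "$3.50"
# expands to "3 dollars, 50 cents", while "$7" expands to "7 dollars" and "$0.01" to "1 cent".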
|
||||
|
||||
def _expand_ordinal(m):
|
||||
return _inflect.number_to_words(m.group(0))
|
||||
|
||||
|
||||
def _expand_number(m):
|
||||
num = int(m.group(0))
|
||||
if num > 1000 and num < 3000:
|
||||
if num == 2000:
|
||||
return "two thousand"
|
||||
elif num > 2000 and num < 2010:
|
||||
return "two thousand " + _inflect.number_to_words(num % 100)
|
||||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + " hundred"
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword="")
|
||||
|
||||
|
||||
def normalize_numbers(text):
|
||||
text = re.sub(_comma_number_re, _remove_commas, text)
|
||||
text = re.sub(_pounds_re, r"\1 pounds", text)
|
||||
text = re.sub(_dollars_re, _expand_dollars, text)
|
||||
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
|
||||
text = re.sub(_ordinal_re, _expand_ordinal, text)
|
||||
text = re.sub(_number_re, _expand_number, text)
|
||||
return text
|
||||
|
||||
|
||||
# Regular expression matching whitespace:
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [
|
||||
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
||||
for x in [
|
||||
("mrs", "misess"),
|
||||
("mr", "mister"),
|
||||
("dr", "doctor"),
|
||||
("st", "saint"),
|
||||
("co", "company"),
|
||||
("jr", "junior"),
|
||||
("maj", "major"),
|
||||
("gen", "general"),
|
||||
("drs", "doctors"),
|
||||
("rev", "reverend"),
|
||||
("lt", "lieutenant"),
|
||||
("hon", "honorable"),
|
||||
("sgt", "sergeant"),
|
||||
("capt", "captain"),
|
||||
("esq", "esquire"),
|
||||
("ltd", "limited"),
|
||||
("col", "colonel"),
|
||||
("ft", "fort"),
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
for regex, replacement in _abbreviations:
|
||||
text = re.sub(regex, replacement, text)
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
return text
|
||||
|
||||
|
||||
def expand_numbers_multilang(text, lang):
|
||||
# TODO: Handle text more carefully. Currently, it just converts numbers without any context.
|
||||
# Find all numbers in the input string
|
||||
numbers = re.findall(r"\d+", text)
|
||||
|
||||
# Transliterate the numbers to text
|
||||
for num in numbers:
|
||||
transliterated_num = "".join(num2words(num, lang=lang))
|
||||
text = text.replace(num, transliterated_num, 1)
|
||||
|
||||
return text
|
||||
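# Intended behavior (hedged illustration, not part of the original commit): for lang="es",
# each digit run such as "3" is meant to be replaced by its num2words spelling ("tres"),
# so "hay 3 gatos" would become "hay tres gatos".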
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
"""Pipeline for non-English text that transliterates to ASCII."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def multilingual_cleaners(text, lang):
|
||||
text = lowercase(text)
|
||||
text = expand_numbers_multilang(text, lang)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
if lang == "tr":
|
||||
text = text.replace("İ", "i")
|
||||
text = text.replace("Ö", "ö")
|
||||
text = text.replace("Ü", "ü")
|
||||
return text
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
"""Pipeline for English text, including number and abbreviation expansion."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
text = expand_abbreviations(text)
|
||||
text = collapse_whitespace(text)
|
||||
text = text.replace('"', "")
|
||||
return text
|
||||
|
||||
|
||||
def remove_extraneous_punctuation(word):
|
||||
replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "ʼ": "'"}
|
||||
replace = re.compile(
|
||||
"|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
|
||||
)
|
||||
word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
|
||||
|
||||
# TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
|
||||
extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
|
||||
word = extraneous.sub("", word)
|
||||
return word
|
||||
|
||||
|
||||
def expand_numbers(text):
|
||||
return normalize_numbers(text)
|
||||
|
||||
|
||||
def lowercase(text):
|
||||
return text.lower()
|
||||
|
||||
|
||||
_whitespace_re = re.compile(r"\s+")
|
||||
|
||||
|
||||
def collapse_whitespace(text):
|
||||
return re.sub(_whitespace_re, " ", text)
|
||||
|
||||
|
||||
def convert_to_ascii(text):
|
||||
return unidecode(text)
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def arabic_cleaners(text):
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def chinese_cleaners(text):
|
||||
text = lowercase(text)
|
||||
text = "".join(
|
||||
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
class VoiceBpeTokenizer:
|
||||
def __init__(self, vocab_file=None, preprocess=None):
|
||||
self.tokenizer = None
|
||||
|
||||
if vocab_file is not None:
|
||||
with open(vocab_file, "r", encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
self.language = vocab["model"]["language"] if "language" in vocab["model"] else None
|
||||
|
||||
if preprocess is None:
|
||||
self.preprocess = "pre_tokenizer" in vocab and vocab["pre_tokenizer"]
|
||||
else:
|
||||
self.preprocess = preprocess
|
||||
|
||||
self.tokenizer = Tokenizer.from_file(vocab_file)
|
||||
|
||||
def preprocess_text(self, txt, lang):
|
||||
if lang == "ja":
|
||||
import pykakasi
|
||||
|
||||
kks = pykakasi.kakasi()
|
||||
results = kks.convert(txt)
|
||||
txt = " ".join([result["kana"] for result in results])
|
||||
txt = basic_cleaners(txt)
|
||||
elif lang == "en":
|
||||
txt = english_cleaners(txt)
|
||||
elif lang == "ar":
|
||||
txt = arabic_cleaners(txt)
|
||||
elif lang == "zh-cn":
|
||||
txt = chinese_cleaners(txt)
|
||||
else:
|
||||
txt = multilingual_cleaners(txt, lang)
|
||||
return txt
|
||||
|
||||
def encode(self, txt, lang):
|
||||
if self.preprocess:
|
||||
txt = self.preprocess_text(txt, lang)
|
||||
txt = txt.replace(" ", "[SPACE]")
|
||||
return self.tokenizer.encode(txt).ids
|
||||
|
||||
def decode(self, seq):
|
||||
if isinstance(seq, torch.Tensor):
|
||||
seq = seq.cpu().numpy()
|
||||
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
|
||||
txt = txt.replace("[SPACE]", " ")
|
||||
txt = txt.replace("[STOP]", "")
|
||||
txt = txt.replace("[UNK]", "")
|
||||
return txt
|
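# Hedged usage sketch (the vocab path is a placeholder, not part of the original commit):
# >>> tokenizer = VoiceBpeTokenizer(vocab_file="path/to/vocab.json")
# >>> ids = tokenizer.encode("Hello world", lang="en")  # text is cleaned, spaces become [SPACE]
# >>> tokenizer.decode(ids)  # roughly round-trips to the cleaned text: 'hello world'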
|
@ -0,0 +1,385 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
|
||||
"""Kernel predictor for the location-variable convolutions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cond_channels,
|
||||
conv_in_channels,
|
||||
conv_out_channels,
|
||||
conv_layers,
|
||||
conv_kernel_size=3,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
kpnet_nonlinear_activation="LeakyReLU",
|
||||
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cond_channels (int): number of channel for the conditioning sequence,
|
||||
conv_in_channels (int): number of channel for the input sequence,
|
||||
conv_out_channels (int): number of channel for the output sequence,
|
||||
conv_layers (int): number of layers
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv_in_channels = conv_in_channels
|
||||
self.conv_out_channels = conv_out_channels
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||
|
||||
self.input_conv = nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_convs = nn.ModuleList()
|
||||
padding = (kpnet_conv_size - 1) // 2
|
||||
for _ in range(3):
|
||||
self.residual_convs.append(
|
||||
nn.Sequential(
|
||||
nn.Dropout(kpnet_dropout),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_hidden_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_hidden_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
)
|
||||
self.kernel_conv = nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_kernel_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
self.bias_conv = nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
kpnet_bias_channels,
|
||||
kpnet_conv_size,
|
||||
padding=padding,
|
||||
bias=True,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, c):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
"""
|
||||
batch, _, cond_length = c.shape
|
||||
c = self.input_conv(c)
|
||||
for residual_conv in self.residual_convs:
|
||||
residual_conv.to(c.device)
|
||||
c = c + residual_conv(c)
|
||||
k = self.kernel_conv(c)
|
||||
b = self.bias_conv(c)
|
||||
kernels = k.contiguous().view(
|
||||
batch,
|
||||
self.conv_layers,
|
||||
self.conv_in_channels,
|
||||
self.conv_out_channels,
|
||||
self.conv_kernel_size,
|
||||
cond_length,
|
||||
)
|
||||
bias = b.contiguous().view(
|
||||
batch,
|
||||
self.conv_layers,
|
||||
self.conv_out_channels,
|
||||
cond_length,
|
||||
)
|
||||
|
||||
return kernels, bias
|
||||
|
||||
def remove_weight_norm(self):
|
||||
nn.utils.remove_weight_norm(self.input_conv[0])
|
||||
nn.utils.remove_weight_norm(self.kernel_conv)
|
||||
nn.utils.remove_weight_norm(self.bias_conv)
|
||||
for block in self.residual_convs:
|
||||
nn.utils.remove_weight_norm(block[1])
|
||||
nn.utils.remove_weight_norm(block[3])
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
|
||||
"""the location-variable convolutions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
cond_channels,
|
||||
stride,
|
||||
dilations=[1, 3, 9, 27],
|
||||
lReLU_slope=0.2,
|
||||
conv_kernel_size=3,
|
||||
cond_hop_length=256,
|
||||
kpnet_hidden_channels=64,
|
||||
kpnet_conv_size=3,
|
||||
kpnet_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.cond_hop_length = cond_hop_length
|
||||
self.conv_layers = len(dilations)
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
|
||||
self.kernel_predictor = KernelPredictor(
|
||||
cond_channels=cond_channels,
|
||||
conv_in_channels=in_channels,
|
||||
conv_out_channels=2 * in_channels,
|
||||
conv_layers=len(dilations),
|
||||
conv_kernel_size=conv_kernel_size,
|
||||
kpnet_hidden_channels=kpnet_hidden_channels,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
kpnet_dropout=kpnet_dropout,
|
||||
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
||||
)
|
||||
|
||||
self.convt_pre = nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
2 * stride,
|
||||
stride=stride,
|
||||
padding=stride // 2 + stride % 2,
|
||||
output_padding=stride % 2,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.conv_blocks = nn.ModuleList()
|
||||
for dilation in dilations:
|
||||
self.conv_blocks.append(
|
||||
nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
conv_kernel_size,
|
||||
padding=dilation * (conv_kernel_size - 1) // 2,
|
||||
dilation=dilation,
|
||||
)
|
||||
),
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
"""forward propagation of the location-variable convolutions.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length)
|
||||
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
||||
|
||||
Returns:
|
||||
Tensor: the output sequence (batch, in_channels, in_length)
|
||||
"""
|
||||
_, in_channels, _ = x.shape # (B, c_g, L')
|
||||
|
||||
x = self.convt_pre(x) # (B, c_g, stride * L')
|
||||
kernels, bias = self.kernel_predictor(c)
|
||||
|
||||
for i, conv in enumerate(self.conv_blocks):
|
||||
output = conv(x) # (B, c_g, stride * L')
|
||||
|
||||
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
||||
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
||||
|
||||
output = self.location_variable_convolution(
|
||||
output, k, b, hop_size=self.cond_hop_length
|
||||
) # (B, 2 * c_g, stride * L'): LVC
|
||||
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
||||
output[:, in_channels:, :]
|
||||
) # (B, c_g, stride * L'): GAU
|
||||
|
||||
return x
|
||||
|
||||
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
||||
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
||||
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
||||
Args:
|
||||
x (Tensor): the input sequence (batch, in_channels, in_length).
|
||||
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
||||
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
||||
dilation (int): the dilation of convolution.
|
||||
hop_size (int): the hop_size of the conditioning sequence.
|
||||
Returns:
|
||||
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
||||
"""
|
||||
batch, _, in_length = x.shape
|
||||
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), "constant", 0)
|
||||
x = x.unfold(
|
||||
3, dilation, dilation
|
||||
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
|
||||
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
||||
o = o.to(memory_format=torch.channels_last_3d)
|
||||
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
||||
o = o + bias
|
||||
o = o.contiguous().view(batch, out_channels, -1)
|
||||
|
||||
return o
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.kernel_predictor.remove_weight_norm()
|
||||
nn.utils.remove_weight_norm(self.convt_pre[1])
|
||||
for block in self.conv_blocks:
|
||||
nn.utils.remove_weight_norm(block[1])
|
||||
|
||||
|
||||
class UnivNetGenerator(nn.Module):
|
||||
"""
|
||||
UnivNet Generator
|
||||
|
||||
Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
noise_dim=64,
|
||||
channel_size=32,
|
||||
dilations=[1, 3, 9, 27],
|
||||
strides=[8, 8, 4],
|
||||
lReLU_slope=0.2,
|
||||
kpnet_conv_size=3,
|
||||
# Below are MEL configurations options that this generator requires.
|
||||
hop_length=256,
|
||||
n_mel_channels=100,
|
||||
):
|
||||
super(UnivNetGenerator, self).__init__()
|
||||
self.mel_channel = n_mel_channels
|
||||
self.noise_dim = noise_dim
|
||||
self.hop_length = hop_length
|
||||
channel_size = channel_size
|
||||
kpnet_conv_size = kpnet_conv_size
|
||||
|
||||
self.res_stack = nn.ModuleList()
|
||||
hop_length = 1
|
||||
for stride in strides:
|
||||
hop_length = stride * hop_length
|
||||
self.res_stack.append(
|
||||
LVCBlock(
|
||||
channel_size,
|
||||
n_mel_channels,
|
||||
stride=stride,
|
||||
dilations=dilations,
|
||||
lReLU_slope=lReLU_slope,
|
||||
cond_hop_length=hop_length,
|
||||
kpnet_conv_size=kpnet_conv_size,
|
||||
)
|
||||
)
|
||||
|
||||
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
|
||||
|
||||
self.conv_post = nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, c, z):
|
||||
"""
|
||||
Args:
|
||||
c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length)
|
||||
z (Tensor): the noise sequence (batch, noise_dim, in_length)
|
||||
|
||||
"""
|
||||
z = self.conv_pre(z) # (B, c_g, L)
|
||||
|
||||
for res_block in self.res_stack:
|
||||
res_block.to(z.device)
|
||||
z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i)
|
||||
|
||||
z = self.conv_post(z) # (B, 1, L * 256)
|
||||
|
||||
return z
|
||||
|
||||
def eval(self, inference=False):
|
||||
super(UnivNetGenerator, self).eval()
|
||||
# don't remove weight norm while validation in training loop
|
||||
if inference:
|
||||
self.remove_weight_norm()
|
||||
|
||||
def remove_weight_norm(self):
|
||||
nn.utils.remove_weight_norm(self.conv_pre)
|
||||
|
||||
for layer in self.conv_post:
|
||||
if len(layer.state_dict()) != 0:
|
||||
nn.utils.remove_weight_norm(layer)
|
||||
|
||||
for res_block in self.res_stack:
|
||||
res_block.remove_weight_norm()
|
||||
|
||||
def inference(self, c, z=None):
|
||||
# pad input mel with zeros to cut artifact
|
||||
# see https://github.com/seungwonpark/melgan/issues/8
|
||||
zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device)
|
||||
mel = torch.cat((c, zero), dim=2)
|
||||
|
||||
if z is None:
|
||||
z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device)
|
||||
|
||||
audio = self.forward(mel, z)
|
||||
audio = audio[:, :, : -(self.hop_length * 10)]
|
||||
audio = audio.clamp(min=-1, max=1)
|
||||
return audio
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = UnivNetGenerator()
|
||||
|
||||
c = torch.randn(3, 100, 10)
|
||||
z = torch.randn(3, 64, 10)
|
||||
print(c.shape)
|
||||
|
||||
y = model(c, z)
|
||||
print(y.shape)
|
||||
assert y.shape == torch.Size([3, 1, 2560])
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print(pytorch_total_params)
|
|
@ -0,0 +1,654 @@
|
|||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel
|
||||
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
|
||||
from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
|
||||
from TTS.tts.layers.xtts.gpt import GPT
|
||||
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
|
||||
from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
||||
|
||||
def load_audio(audiopath, sr=22050):
|
||||
"""
|
||||
Load an audio file from disk and resample it to the specified sampling rate.
|
||||
|
||||
Args:
|
||||
audiopath (str): Path to the audio file.
|
||||
sr (int): Target sampling rate.
|
||||
|
||||
Returns:
|
||||
Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples.
|
||||
"""
|
||||
audio, sampling_rate = torchaudio.load(audiopath)
|
||||
|
||||
if len(audio.shape) > 1:
|
||||
if audio.shape[0] < 5:
|
||||
audio = audio[0]
|
||||
else:
|
||||
assert audio.shape[1] < 5
|
||||
audio = audio[:, 0]
|
||||
|
||||
if sampling_rate != sr:
|
||||
resampler = torchaudio.transforms.Resample(sampling_rate, sr)
|
||||
audio = resampler(audio)
|
||||
|
||||
audio = audio.clamp_(-1, 1)
|
||||
return audio.unsqueeze(0)
|
||||
|
||||
|
||||
def wav_to_mel_cloning(
|
||||
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
|
||||
):
|
||||
"""
|
||||
Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
|
||||
|
||||
Args:
|
||||
wav (torch.Tensor): Input waveform tensor.
|
||||
mel_norms_file (str): Path to mel-spectrogram normalization file.
|
||||
mel_norms (torch.Tensor): Mel-spectrogram normalization tensor.
|
||||
device (torch.device): Device to use for computation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel-spectrogram tensor.
|
||||
"""
|
||||
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
n_fft=4096,
|
||||
hop_length=1024,
|
||||
win_length=4096,
|
||||
power=2,
|
||||
normalized=False,
|
||||
sample_rate=22050,
|
||||
f_min=0,
|
||||
f_max=8000,
|
||||
n_mels=80,
|
||||
norm="slaney",
|
||||
).to(device)
|
||||
wav = wav.to(device)
|
||||
mel = mel_stft(wav)
|
||||
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
if mel_norms is None:
|
||||
mel_norms = torch.load(mel_norms_file, map_location=device)
|
||||
mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
|
||||
return mel
|
||||
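# Illustrative shapes (assumed, not part of the original commit): a 3-second 22050 Hz clip of
# shape (1, 66150) yields a normalized mel of shape (1, 80, ~65) given n_fft=4096, hop_length=1024.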
|
||||
|
||||
def pad_or_truncate(t, length):
|
||||
"""
|
||||
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
|
||||
|
||||
Args:
|
||||
t (torch.Tensor): The input tensor to be padded or truncated.
|
||||
length (int): The desired length of the tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The padded or truncated tensor.
|
||||
"""
|
||||
tp = t[..., :length]
|
||||
if t.shape[-1] == length:
|
||||
tp = t
|
||||
elif t.shape[-1] < length:
|
||||
tp = F.pad(t, (0, length - t.shape[-1]))
|
||||
return tp
|
||||
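# Illustrative example (not part of the original commit): for t of shape (1, 80, 500),
# pad_or_truncate(t, 1024) zero-pads to (1, 80, 1024), while pad_or_truncate(t, 300)
# clips to (1, 80, 300).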
|
||||
|
||||
def load_discrete_vocoder_diffuser(
|
||||
trained_diffusion_steps=4000,
|
||||
desired_diffusion_steps=200,
|
||||
cond_free=True,
|
||||
cond_free_k=1,
|
||||
sampler="ddim",
|
||||
):
|
||||
"""
|
||||
Load a GaussianDiffusion instance configured for use as a decoder.
|
||||
|
||||
Args:
|
||||
trained_diffusion_steps (int): The number of diffusion steps used during training.
|
||||
desired_diffusion_steps (int): The number of diffusion steps to use during inference.
|
||||
cond_free (bool): Whether to use a conditioning-free model.
|
||||
cond_free_k (int): The number of samples to use for conditioning-free models.
|
||||
sampler (str): The name of the sampler to use.
|
||||
|
||||
Returns:
|
||||
A SpacedDiffusion instance configured with the given parameters.
|
||||
"""
|
||||
return SpacedDiffusion(
|
||||
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
|
||||
model_mean_type="epsilon",
|
||||
model_var_type="learned_range",
|
||||
loss_type="mse",
|
||||
betas=get_named_beta_schedule("linear", trained_diffusion_steps),
|
||||
conditioning_free=cond_free,
|
||||
conditioning_free_k=cond_free_k,
|
||||
sampler=sampler,
|
||||
)
|
||||
|
||||
|
||||
def do_spectrogram_diffusion(
|
||||
diffusion_model,
|
||||
diffuser,
|
||||
latents,
|
||||
conditioning_latents,
|
||||
temperature=1,
|
||||
):
|
||||
"""
|
||||
Generate a mel-spectrogram using a diffusion model and a diffuser.
|
||||
|
||||
Args:
|
||||
diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise.
|
||||
latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes.
|
||||
conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes.
|
||||
temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
output_seq_len = (
|
||||
latents.shape[1] * 4 * 24000 // 22050
|
||||
) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
|
||||
output_shape = (latents.shape[0], 100, output_seq_len)
|
||||
precomputed_embeddings = diffusion_model.timestep_independent(
|
||||
latents, conditioning_latents, output_seq_len, False
|
||||
)
|
||||
|
||||
noise = torch.randn(output_shape, device=latents.device) * temperature
|
||||
mel = diffuser.sample_loop(
|
||||
diffusion_model,
|
||||
output_shape,
|
||||
noise=noise,
|
||||
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
|
||||
progress=False,
|
||||
)
|
||||
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
|
||||
|
||||
|
||||
@dataclass
|
||||
class XttsAudioConfig(Coqpit):
|
||||
"""
|
||||
Configuration class for audio-related parameters in the XTTS model.
|
||||
|
||||
Args:
|
||||
sample_rate (int): The sample rate in which the GPT operates.
|
||||
diffusion_sample_rate (int): The sample rate of the diffusion audio waveform.
|
||||
output_sample_rate (int): The sample rate of the output audio waveform.
|
||||
"""
|
||||
|
||||
sample_rate: int = 22050
|
||||
diffusion_sample_rate: int = 24000
|
||||
output_sample_rate: int = 24000
|
||||
|
||||
|
||||
@dataclass
|
||||
class XttsArgs(Coqpit):
|
||||
"""A dataclass to represent XTTS model arguments that define the model structure.
|
||||
|
||||
Args:
|
||||
gpt_batch_size (int): The size of the auto-regressive batch.
|
||||
enable_redaction (bool, optional): Whether to enable redaction. Defaults to False.
lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to True.
|
||||
kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
|
||||
gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
|
||||
clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
|
||||
decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
|
||||
num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
|
||||
vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
|
||||
|
||||
For GPT model:
|
||||
ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
|
||||
ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
|
||||
ar_max_prompt_tokens (int, optional): The maximum prompt tokens for the autoregressive model. Defaults to 70.
|
||||
ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
|
||||
ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
|
||||
ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
|
||||
ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
|
||||
ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
|
||||
gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
|
||||
ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
|
||||
|
||||
For DiffTTS model:
|
||||
diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024.
|
||||
diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10.
|
||||
diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100.
|
||||
diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200.
|
||||
diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024.
|
||||
diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193.
|
||||
diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0.
|
||||
diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False.
|
||||
diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16.
|
||||
diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0.
|
||||
diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0.
|
||||
"""
|
||||
|
||||
gpt_batch_size: int = 1
|
||||
enable_redaction: bool = False
|
||||
lazy_load: bool = True
|
||||
kv_cache: bool = True
|
||||
gpt_checkpoint: str = None
|
||||
clvp_checkpoint: str = None
|
||||
decoder_checkpoint: str = None
|
||||
num_chars: int = 255
|
||||
|
||||
# XTTS GPT Encoder params
|
||||
tokenizer_file: str = ""
|
||||
gpt_max_audio_tokens: int = 605
|
||||
gpt_max_text_tokens: int = 402
|
||||
gpt_max_prompt_tokens: int = 70
|
||||
gpt_layers: int = 30
|
||||
gpt_n_model_channels: int = 1024
|
||||
gpt_n_heads: int = 16
|
||||
gpt_number_text_tokens: int = None
|
||||
gpt_start_text_token: int = None
|
||||
gpt_stop_text_token: int = None
|
||||
gpt_num_audio_tokens: int = 8194
|
||||
gpt_start_audio_token: int = 8192
|
||||
gpt_stop_audio_token: int = 8193
|
||||
|
||||
# Diffusion Decoder params
|
||||
diff_model_channels: int = 1024
|
||||
diff_num_layers: int = 10
|
||||
diff_in_channels: int = 100
|
||||
diff_out_channels: int = 200
|
||||
diff_in_latent_channels: int = 1024
|
||||
diff_in_tokens: int = 8193
|
||||
diff_dropout: int = 0
|
||||
diff_use_fp16: bool = False
|
||||
diff_num_heads: int = 16
|
||||
diff_layer_drop: int = 0
|
||||
diff_unconditioned_percentage: int = 0
|
||||
|
||||
# constants
|
||||
duration_const: int = 102400
|
||||
|
||||
|
||||
class Xtts(BaseTTS):
|
||||
"""ⓍTTS model implementation.
|
||||
|
||||
❗ Currently it only supports inference.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs.xtts_config import XttsConfig
|
||||
>>> from TTS.tts.models.xtts import Xtts
|
||||
>>> config = XttsConfig()
|
||||
>>> model = Xtts.init_from_config(config)
|
||||
>>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config, ap=None, tokenizer=None)
|
||||
self.lazy_load = self.args.lazy_load
|
||||
self.mel_stats_path = None
|
||||
self.config = config
|
||||
self.gpt_checkpoint = self.args.gpt_checkpoint
|
||||
self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed
|
||||
self.models_dir = config.model_dir
|
||||
self.gpt_batch_size = self.args.gpt_batch_size
|
||||
|
||||
self.tokenizer = VoiceBpeTokenizer()
|
||||
self.gpt = None
|
||||
self.diffusion_decoder = None
|
||||
self.init_models()
|
||||
        self.register_buffer("mel_stats", torch.ones(80))

    def init_models(self):
        """Initialize the models. We do it here since we need to load the tokenizer first."""
        if self.tokenizer.tokenizer is not None:
            self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
            self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
            self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")

        if self.args.gpt_number_text_tokens:
            self.gpt = GPT(
                layers=self.args.gpt_layers,
                model_dim=self.args.gpt_n_model_channels,
                start_text_token=self.args.gpt_start_text_token,
                stop_text_token=self.args.gpt_stop_text_token,
                heads=self.args.gpt_n_heads,
                max_text_tokens=self.args.gpt_max_text_tokens,
                max_mel_tokens=self.args.gpt_max_audio_tokens,
                max_prompt_tokens=self.args.gpt_max_prompt_tokens,
                number_text_tokens=self.args.gpt_number_text_tokens,
                num_audio_tokens=self.args.gpt_num_audio_tokens,
                start_audio_token=self.args.gpt_start_audio_token,
                stop_audio_token=self.args.gpt_stop_audio_token,
            )

        self.diffusion_decoder = DiffusionTts(
            model_channels=self.args.diff_model_channels,
            num_layers=self.args.diff_num_layers,
            in_channels=self.args.diff_in_channels,
            out_channels=self.args.diff_out_channels,
            in_latent_channels=self.args.diff_in_latent_channels,
            in_tokens=self.args.diff_in_tokens,
            dropout=self.args.diff_dropout,
            use_fp16=self.args.diff_use_fp16,
            num_heads=self.args.diff_num_heads,
            layer_drop=self.args.diff_layer_drop,
            unconditioned_percentage=self.args.diff_unconditioned_percentage,
        )

        self.vocoder = UnivNetGenerator()
||||
    @property
    def device(self):
        return next(self.parameters()).device

    @contextmanager
    def lazy_load_model(self, model):
        """Context manager that yields the given model, handling device placement.

        If `lazy_load` is disabled, the model is moved to `self.device` for the duration of the
        context and moved back to the CPU on exit; otherwise it is yielded as-is.

        Args:
            model (nn.Module): The model to be loaded.
        """
        if self.lazy_load:
            yield model
        else:
            m = model.to(self.device)
            yield m
            m = model.cpu()
    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
        """Compute the conditioning latents for the GPT model from the given audio.

        Args:
            audio_path (str): Path to the audio file.
            length (int): Length of the audio in seconds. Defaults to 3.
        """
        audio = load_audio(audio_path)
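        # Keep only the first `length` seconds of the reference clip (the slice assumes a 22050 Hz waveform).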
        audio = audio[:, : 22050 * length]
        mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
        cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
        return cond_latent.transpose(1, 2)
    def get_diffusion_cond_latents(
        self,
        audio_path,
    ):
        from math import ceil

        diffusion_conds = []
        CHUNK_SIZE = 102400
        audio = load_audio(audio_path, 24000)
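        # Split the 24 kHz reference into fixed 102400-sample chunks (~4.27 s each), padding or
        # truncating the last one, and compute a UnivNet-style mel per chunk for diffusion conditioning.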
        for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)):
            current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE]
            current_sample = pad_or_truncate(current_sample, CHUNK_SIZE)
            cond_mel = wav_to_univnet_mel(
                current_sample.to(self.device),
                do_normalization=False,
                device=self.device,
            )
            diffusion_conds.append(cond_mel)
        diffusion_conds = torch.stack(diffusion_conds, dim=1)

        with self.lazy_load_model(self.diffusion_decoder) as diffusion:
            diffusion_latent = diffusion.get_conditioning(diffusion_conds)
        return diffusion_latent
    def get_conditioning_latents(
        self,
        audio_path,
        gpt_cond_len=3,
    ):
        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
        diffusion_cond_latents = self.get_diffusion_cond_latents(
            audio_path,
        )
        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
    def synthesize(self, text, config, speaker_wav, language, **kwargs):
        """Synthesize speech with the given input text.

        Args:
            text (str): Input text.
            config (XttsConfig): Config with inference parameters.
            speaker_wav (str): Path to the speaker audio file for cloning.
            language (str): Language ID of the speaker.
            **kwargs: Inference settings. See `inference()`.

        Returns:
            A dictionary of the output values with `wav` as the output waveform.
        """
        # Make the synthesizer happy 🥳
        if isinstance(speaker_wav, list):
            speaker_wav = speaker_wav[0]

        return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs)
    def inference_with_config(self, text, config, ref_audio_path, language, **kwargs):
        """Run `inference()` with generation settings taken from the config, optionally overridden by kwargs."""
        assert (
            language in self.config.languages
        ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"
        # Use generally found best tuning knobs for generation.
        settings = {
            "temperature": config.temperature,
            "length_penalty": config.length_penalty,
            "repetition_penalty": config.repetition_penalty,
            "top_k": config.top_k,
            "top_p": config.top_p,
            "cond_free_k": config.cond_free_k,
            "diffusion_temperature": config.diffusion_temperature,
            "decoder_iterations": config.decoder_iterations,
            "decoder_sampler": config.decoder_sampler,
        }
        settings.update(kwargs)  # allow overriding of preset settings with kwargs
        return self.inference(text, ref_audio_path, language, **settings)
    @torch.no_grad()
    def inference(
        self,
        text,
        ref_audio_path,
        language,
        # GPT inference
        temperature=0.65,
        length_penalty=1,
        repetition_penalty=2.0,
        top_k=50,
        top_p=0.85,
        gpt_cond_len=4,
        do_sample=True,
        # Decoder inference
        decoder_iterations=100,
        cond_free=True,
        cond_free_k=2,
        diffusion_temperature=1.0,
        decoder_sampler="ddim",
        **hf_generate_kwargs,
    ):
"""
|
||||
This function produces an audio clip of the given text being spoken with the given reference voice.
|
||||
|
||||
Args:
|
||||
text: (str) Text to be spoken.
|
||||
|
||||
ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3
|
||||
seconds long.
|
||||
|
||||
language: (str) Language of the voice to be generated.
|
||||
|
||||
temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65.
|
||||
|
||||
length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the
|
||||
model to produce more terse outputs. Defaults to 1.0.
|
||||
|
||||
repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during
|
||||
decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0.
|
||||
|
||||
top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely"
|
||||
(aka boring) outputs. Defaults to 50.
|
||||
|
||||
top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely"
|
||||
(aka boring) outputs. Defaults to 0.8.
|
||||
|
||||
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
|
||||
else the first `gpt_cond_len` secs is used. Defaults to 3 seconds.
|
||||
|
||||
decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has
|
||||
more chances to iteratively refine the output, which should theoretically mean a higher quality output.
|
||||
Generally a value above 250 is not noticeably better, however. Defaults to 100.
|
||||
|
||||
cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion
|
||||
performs two forward passes for each diffusion step: one with the outputs of the autoregressive model
|
||||
and one with no conditioning priors. The output of the two is blended according to the cond_free_k
|
||||
value below. Conditioning-free diffusion is the real deal, and dramatically improves realism.
|
||||
Defaults to True.
|
||||
|
||||
cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the
|
||||
conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the
|
||||
conditioning-free signal. Defaults to 2.0.
|
||||
|
||||
diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1].
|
||||
Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared.
|
||||
Defaults to 1.0.
|
||||
|
||||
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
|
||||
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
|
||||
here: https://huggingface.co/docs/transformers/internal/generation_utils
|
||||
|
||||
Returns:
|
||||
Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
|
||||
Sample rate is 24kHz.
|
||||
"""
|
||||
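        # Prepend the language tag (e.g. "[en]") and lower-case the input text before tokenization.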
text = f"[{language}]{text.strip().lower()}"
|
||||
text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
|
||||
|
||||
assert (
|
||||
text_tokens.shape[-1] < self.args.gpt_max_text_tokens
|
||||
), " ❗ XTTS can only generate text with a maximum of 400 tokens."
|
||||
|
||||
(
|
||||
gpt_cond_latent,
|
||||
diffusion_conditioning,
|
||||
) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
|
||||
|
||||
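        # Configure the diffusion sampler: `decoder_iterations` denoising steps with optional
        # conditioning-free guidance (weighted by `cond_free_k`) and the chosen sampler (e.g. "ddim").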
        diffuser = load_discrete_vocoder_diffuser(
            desired_diffusion_steps=decoder_iterations,
            cond_free=cond_free,
            cond_free_k=cond_free_k,
            sampler=decoder_sampler,
        )
        with torch.no_grad():
            self.gpt = self.gpt.to(self.device)
            with self.lazy_load_model(self.gpt) as gpt:
                gpt_codes = gpt.generate(
                    cond_latents=gpt_cond_latent,
                    text_inputs=text_tokens,
                    input_tokens=None,
                    do_sample=do_sample,
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    num_return_sequences=self.gpt_batch_size,
                    length_penalty=length_penalty,
                    repetition_penalty=repetition_penalty,
                    output_attentions=False,
                    **hf_generate_kwargs,
                )

            with self.lazy_load_model(self.gpt) as gpt:
                expected_output_len = torch.tensor(
                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
                )
                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
                gpt_latents = gpt(
                    text_tokens,
                    text_len,
                    gpt_codes,
                    expected_output_len,
                    cond_latents=gpt_cond_latent,
                    return_attentions=False,
                    return_latent=True,
                )
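            # Token 83 is treated as the silence code: if more than 8 consecutive silence tokens are
            # generated, truncate the GPT latents there to avoid synthesizing long trailing silence.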
            silence_token = 83
            ctokens = 0
            for k in range(gpt_codes.shape[-1]):
                if gpt_codes[0, k] == silence_token:
                    ctokens += 1
                else:
                    ctokens = 0
                if ctokens > 8:
                    gpt_latents = gpt_latents[:, :k]
                    break

            with self.lazy_load_model(self.diffusion_decoder) as diffusion:
                mel = do_spectrogram_diffusion(
                    diffusion,
                    diffuser,
                    gpt_latents,
                    diffusion_conditioning,
                    temperature=diffusion_temperature,
                )
            with self.lazy_load_model(self.vocoder) as vocoder:
                wav = vocoder.inference(mel)

        return {"wav": wav.cpu().numpy().squeeze()}
    def forward(self):
        raise NotImplementedError("XTTS Training is not implemented")

    def eval_step(self):
        raise NotImplementedError("XTTS Training is not implemented")

    @staticmethod
    def init_from_config(config: "XttsConfig", **kwargs):  # pylint: disable=unused-argument
        return Xtts(config)

    def eval(self):  # pylint: disable=redefined-builtin
        """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode."""
        self.gpt.init_gpt_for_inference()
        super().eval()
    def load_checkpoint(
        self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
    ):
        """
        Loads a checkpoint from disk and initializes the model's state and tokenizer.

        Args:
            config (dict): The configuration dictionary for the model.
            checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
            checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
            vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
            eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
            strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.

        Returns:
            None
        """

        model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
        vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")

        if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
            self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))

        self.init_models()
        if eval:
            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
        self.load_state_dict(load_fsspec(model_path)["model"], strict=strict)

        if eval:
            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
            self.gpt.eval()
            self.diffusion_decoder.eval()
            self.vocoder.eval()

    def train_step(self):
        raise NotImplementedError("XTTS Training is not implemented")

@@ -8,7 +8,9 @@ def init():
        import jpype
        import jpype.imports
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.")
        raise ModuleNotFoundError(
            "Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`."
        )

    try:
        jar_path = os.environ["BEL_FANETYKA_JAR"]

@@ -31,4 +33,5 @@ def belarusian_text_to_phonemes(text: str) -> str:
        init()

    from org.alex73.fanetyka.impl import FanetykaText

    return str(FanetykaText(finder, text).ipa)

@@ -1,6 +1,6 @@
from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer
from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer
from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak
from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut
from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer

@@ -1,7 +1,7 @@
from typing import Dict

from TTS.tts.utils.text.phonemizers.base import BasePhonemizer
from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
from TTS.tts.utils.text.phonemizers.base import BasePhonemizer

_DEF_BE_PUNCS = ",!."  # TODO

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import datetime
import importlib
import logging
import os
import re
import subprocess

@@ -219,3 +220,22 @@ class KeepAverage:
    def update_values(self, value_dict):
        for key, value in value_dict.items():
            self.update_value(key, value)


def get_timestamp():
    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")


def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
    lg.setLevel(level)
    if tofile:
        log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode="w")
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)

@@ -334,7 +334,7 @@ class ModelManager(object):
        output_model_path = output_path
        output_config_path = None
        if (
            model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name
            model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
        ):  # TODO: This is stupid but don't care for now.
            output_model_path, output_config_path = self._find_files(output_path)
            # update paths in the config.json

@@ -242,7 +242,11 @@ class Synthesizer(nn.Module):
            wav (List[int]): waveform as a list of values.
            path (str): output path to save the waveform.
        """
        wav = np.array(wav)
        # if tensor convert to numpy
        if torch.is_tensor(wav):
            wav = wav.cpu().numpy()
        if isinstance(wav, list):
            wav = np.array(wav)
        save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)

    def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:

@@ -334,7 +338,7 @@ class Synthesizer(nn.Module):

        elif language_name and isinstance(language_name, str):
            try:
                language_id = self.tts_model.language_manager.name_to_id[language_name]
                language_id = self.tts_model.language_manager.name_to_id[language_id]
            except KeyError as e:
                raise ValueError(
                    f" [!] Looks like you use a multi-lingual model. "

@@ -374,6 +378,8 @@ class Synthesizer(nn.Module):
                    speaker_id=speaker_name,
                    voice_dirs=self.voice_dir,
                    d_vector=speaker_embedding,
                    speaker_wav=speaker_wav,
                    language=language_name,
                    **kwargs,
                )
            else:

@@ -53,6 +53,7 @@
    models/overflow.md
    models/tortoise.md
    models/bark.md
    models/xtts.md

.. toctree::
    :maxdepth: 2

@@ -1,4 +1,4 @@
# Bark 🐶
# 🐶 Bark

Bark is a multi-lingual TTS model created by [Suno-AI](https://www.suno.ai/). It can generate conversational speech as well as music and sound effects.
It is architecturally very similar to Google's [AudioLM](https://arxiv.org/abs/2209.03143). For more information, please refer to [Suno-AI's repo](https://github.com/suno-ai/bark).

@@ -1,4 +1,4 @@
# Tortoise 🐢
# 🐢 Tortoise
Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on a GPT-like autoregressive acoustic model that converts input
text to discretized acoustic tokens, a diffusion model that converts these tokens to mel-spectrogram frames, and a UnivNet vocoder that converts the spectrograms to
the final audio signal. The important downside is that Tortoise is very slow compared to parallel TTS models like VITS.

@@ -0,0 +1,108 @@
# ⓍTTS
ⓍTTS is a super cool Text-to-Speech model that lets you clone voices in different languages by using just a quick 3-second audio clip. Built on 🐢Tortoise,
ⓍTTS has important model changes that make cross-language voice cloning and multi-lingual speech generation super easy.
There is no need for an excessive amount of training data that spans countless hours.

This is the same model that powers [Coqui Studio](https://coqui.ai/) and [Coqui API](https://docs.coqui.ai/docs); however, we apply
a few tricks to make it faster and to support streaming inference.

### Features
- Voice cloning with just a 3-second audio clip.
- Cross-language voice cloning.
- Multi-lingual speech generation.
- 24kHz sampling rate.

### Code
The current implementation only supports inference.

### Languages
As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).

Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.

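The language codes accepted by the `language` argument (or `--language_idx` on the command line) are listed in the model config. A quick way to check them, assuming the released `config.json` stores them under `languages`:

```python
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")  # config shipped with the model files
print(config.languages)  # codes accepted by `language` / `--language_idx`
```
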
### License
This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml).

### Contact
Come and join our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
You can also mail us at info@coqui.ai.

Using 🐸TTS API:

```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)

# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                file_path="output.wav",
                speaker_wav="/path/to/target/speaker.wav",
                language="en")

# generate speech by cloning a voice using custom settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                file_path="output.wav",
                speaker_wav="/path/to/target/speaker.wav",
                language="en",
                decoder_iterations=30)
```

Using 🐸TTS Command line:

```console
tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
    --text "Bugün okula gitmek istemiyorum." \
    --speaker_wav /path/to/target/speaker.wav \
    --language_idx tr \
    --use_cuda true
```

Using model directly:

```python
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.load_json("/path/to/xtts/config.json")
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
model.cuda()

outputs = model.synthesize(
    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
    config,
    speaker_wav="/data/TTS-public/_refclips/3.wav",
    gpt_cond_len=3,
    language="en",
)
```
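
The returned `outputs` dictionary holds the generated waveform under `outputs["wav"]` as a numpy array at the 24kHz output sample rate. A minimal sketch of writing it to disk, assuming `scipy` is installed (any WAV writer works):

```python
import scipy.io.wavfile

# XTTS generates audio at a 24 kHz sample rate.
scipy.io.wavfile.write("xtts_output.wav", 24000, outputs["wav"])
```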

## Important resources & papers
- VALL-E: https://arxiv.org/abs/2301.02111
- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
- Faster implementation: https://github.com/152334H/tortoise-tts-fast
- UnivNet: https://arxiv.org/abs/2106.07889
- Latent Diffusion: https://arxiv.org/abs/2112.10752
- DALL-E: https://arxiv.org/abs/2102.12092


## XttsConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.xtts_config.XttsConfig
    :members:
```

## XttsArgs
```{eval-rst}
.. autoclass:: TTS.tts.models.xtts.XttsArgs
    :members:
```

## XTTS Model
```{eval-rst}
.. autoclass:: TTS.tts.models.xtts.Xtts
    :members:
```

@@ -60,7 +60,7 @@ config = GlowTTSConfig(
    output_path=output_path,
    add_blank=True,
    datasets=[dataset_config],
    # characters=characters,
    # characters=characters,
    enable_eos_bos_chars=True,
    mixed_precision=False,
    save_step=10000,

@@ -1,6 +1,6 @@
import os
import warnings
import unittest
import warnings

from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes

@@ -17,7 +17,8 @@ class TestText(unittest.TestCase):
        except KeyError:
            warnings.warn(
                "You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
                Warning)
                Warning,
            )
            return

        for line in _TEST_CASES.strip().split("\n"):