From 4033db5f4b8cbfdf906229df043758b439b6a2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 8 Sep 2023 12:40:31 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20XTTS=20implementation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 +- TTS/.models.json | 14 +- TTS/api.py | 3 + TTS/bin/synthesize.py | 6 +- TTS/tts/configs/xtts_config.py | 90 ++ TTS/tts/layers/tortoise/tokenizer.py | 9 +- TTS/tts/layers/xtts/diffusion.py | 1344 +++++++++++++++++ TTS/tts/layers/xtts/dvae.py | 393 +++++ TTS/tts/layers/xtts/gpt.py | 545 +++++++ TTS/tts/layers/xtts/gpt_encoder_eren.py | 658 ++++++++ TTS/tts/layers/xtts/gpt_encoder_old.py | 1057 +++++++++++++ TTS/tts/layers/xtts/gpt_inference.py | 136 ++ TTS/tts/layers/xtts/latent_encoder.py | 141 ++ TTS/tts/layers/xtts/tokenizer.py | 286 ++++ TTS/tts/layers/xtts/vocoder.py | 385 +++++ TTS/tts/models/xtts.py | 654 ++++++++ TTS/tts/utils/text/belarusian/phonemizer.py | 5 +- TTS/tts/utils/text/phonemizers/__init__.py | 2 +- .../text/phonemizers/belarusian_phonemizer.py | 2 +- TTS/utils/generic_utils.py | 20 + TTS/utils/manage.py | 2 +- TTS/utils/synthesizer.py | 10 +- docs/source/index.md | 1 + docs/source/models/bark.md | 2 +- docs/source/models/tortoise.md | 2 +- docs/source/models/xtts.md | 108 ++ recipes/bel-alex73/train_glowtts.py | 2 +- .../text_tests/test_belarusian_phonemizer.py | 5 +- 28 files changed, 5866 insertions(+), 21 deletions(-) create mode 100644 TTS/tts/configs/xtts_config.py create mode 100644 TTS/tts/layers/xtts/diffusion.py create mode 100644 TTS/tts/layers/xtts/dvae.py create mode 100644 TTS/tts/layers/xtts/gpt.py create mode 100644 TTS/tts/layers/xtts/gpt_encoder_eren.py create mode 100644 TTS/tts/layers/xtts/gpt_encoder_old.py create mode 100644 TTS/tts/layers/xtts/gpt_inference.py create mode 100644 TTS/tts/layers/xtts/latent_encoder.py create mode 100644 TTS/tts/layers/xtts/tokenizer.py create mode 100644 TTS/tts/layers/xtts/vocoder.py create mode 100644 TTS/tts/models/xtts.py create mode 100644 docs/source/models/xtts.md diff --git a/README.md b/README.md index 6697a192..474f5499 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea - Delightful TTS: [paper](https://arxiv.org/abs/2110.12612) ### End-to-End Models +- ⓍTTS: [blog]() - VITS: [paper](https://arxiv.org/pdf/2106.06103) - 🐸 YourTTS: [paper](https://arxiv.org/abs/2112.02418) - 🐢 Tortoise: [orig. repo](https://github.com/neonbjb/tortoise-tts) @@ -248,11 +249,11 @@ tts.tts_with_vc_to_file( ``` #### Example using [🐸Coqui Studio](https://coqui.ai) voices. -You access all of your cloned voices and built-in speakers in [🐸Coqui Studio](https://coqui.ai). +You access all of your cloned voices and built-in speakers in [🐸Coqui Studio](https://coqui.ai). To do this, you'll need an API token, which you can obtain from the [account page](https://coqui.ai/account). After obtaining the API token, you'll need to configure the COQUI_STUDIO_TOKEN environment variable. -Once you have a valid API token in place, the studio speakers will be displayed as distinct models within the list. +Once you have a valid API token in place, the studio speakers will be displayed as distinct models within the list. 
These models will follow the naming convention `coqui_studio/en//coqui_studio` ```python diff --git a/TTS/.models.json b/TTS/.models.json index 69ac7514..07ef3902 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -2,6 +2,18 @@ "tts_models": { "multilingual": { "multi-dataset": { + "xtts_v1": { + "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.", + "hf_url": [ + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json", + "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json" + ], + "default_vocoder": null, + "commit": "e9a1953e", + "license": "Coqui Community Model License", + "contact": "info@coqui.ai" + }, "your_tts": { "description": "Your TTS model accompanying the paper https://arxiv.org/abs/2112.02418", "github_rls_url": "https://coqui.gateway.scarf.sh/v0.10.1_models/tts_models--multilingual--multi-dataset--your_tts.zip", @@ -881,4 +893,4 @@ } } } -} +} \ No newline at end of file diff --git a/TTS/api.py b/TTS/api.py index 2ee108ba..1eb0b510 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -105,6 +105,9 @@ class TTS(nn.Module): @property def is_multi_lingual(self): + # TODO: fix this + if "xtts" in self.model_name: + return True if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: return self.synthesizer.tts_model.language_manager.num_languages > 1 return False diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 5ded3067..e8de18b0 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -392,7 +392,7 @@ If you don't specify any models, then it uses LJSpeech based English model. if args.encoder_path is not None: encoder_path = args.encoder_path encoder_config_path = args.encoder_config_path - + device = args.device if args.use_cuda: device = "cuda" @@ -459,7 +459,9 @@ If you don't specify any models, then it uses LJSpeech based English model. target_wav=args.target_wav, ) elif model_dir is not None: - wav = synthesizer.tts(args.text, speaker_name=args.speaker_idx) + wav = synthesizer.tts( + args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav + ) # save the results print(" > Saving output to {}".format(args.out_path)) diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py new file mode 100644 index 00000000..b9685590 --- /dev/null +++ b/TTS/tts/configs/xtts_config.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + + +@dataclass +class XttsConfig(BaseTTSConfig): + """Defines parameters for XTTS TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (XttsArgs): + Model architecture arguments. Defaults to `XttsArgs()`. + + audio (XttsAudioConfig): + Audio processing configuration. Defaults to `XttsAudioConfig()`. + + model_dir (str): + Path to the folder that has all the XTTS models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. 
Since the score is the log likelihood of the sequence (i.e. negative),
+            length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
+
+        repetition_penalty (float):
+            The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`.
+
+        top_p (float):
+            If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+            Defaults to `0.8`.
+
+        cond_free_k (float):
+            Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf].
+            As cond_free_k increases, the output becomes dominated by the conditioning-free signal.
+            Formula is: output=cond_present_output*(cond_free_k+1)-cond_absent_output*cond_free_k. Defaults to `2.0`.
+
+        diffusion_temperature (float):
+            Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0
+            are the "mean" prediction of the diffusion network and will sound bland and smeared.
+            Defaults to `1.0`.
+
+        num_gpt_outputs (int):
+            Number of samples taken from the autoregressive model, all of which are filtered using CLVP.
+            As XTTS is a probabilistic model, more samples means a higher probability of creating something "great".
+            Defaults to `16`.
+
+        decoder_iterations (int):
+            Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine
+            the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better,
+            however. Defaults to `30`.
+
+        decoder_sampler (str):
+            Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`.
+    Note:
+        Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
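A slightly fuller usage sketch than the one-line `Example` that follows in the docstring. This is editorial, not part of the patch; it only sets fields that are declared in the dataclass below, with values chosen purely for illustration.

```python
from TTS.tts.configs.xtts_config import XttsConfig

# Override the documented inference-time knobs; unspecified fields keep
# the defaults declared in the dataclass (e.g. length_penalty=1.0).
config = XttsConfig(
    temperature=0.2,         # autoregressive sampling temperature
    repetition_penalty=2.0,  # 1.0 would disable the penalty
    top_k=50,
    top_p=0.8,
    num_gpt_outputs=16,      # GPT candidates to draw before decoding
    decoder_iterations=30,   # diffusion steps in the decoder
    decoder_sampler="ddim",  # or "dpm++2m"
)
print(config.model, config.languages[:3])
```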
+ + Example: + + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> config = XttsConfig() + """ + + model: str = "xtts" + # model specific params + model_args: XttsArgs = field(default_factory=XttsArgs) + audio: XttsAudioConfig = field(default_factory=XttsAudioConfig) + model_dir: str = None + languages: List[str] = field( + default_factory=lambda: ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn"] + ) + + # inference params + temperature: float = 0.2 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_k: int = 50 + top_p: float = 0.8 + cond_free_k: float = 2.0 + diffusion_temperature: float = 1.0 + num_gpt_outputs: int = 16 + decoder_iterations: int = 30 + decoder_sampler: str = "ddim" diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py index 3e544ee7..3969b2cc 100644 --- a/TTS/tts/layers/tortoise/tokenizer.py +++ b/TTS/tts/layers/tortoise/tokenizer.py @@ -5,15 +5,14 @@ from tokenizers import Tokenizer from TTS.tts.utils.text.cleaners import english_cleaners -DEFAULT_VOCAB_FILE = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" -) - class VoiceBpeTokenizer: - def __init__(self, vocab_file=DEFAULT_VOCAB_FILE): + def __init__(self, vocab_file=None, vocab_str=None): + self.tokenizer = None if vocab_file is not None: self.tokenizer = Tokenizer.from_file(vocab_file) + if vocab_str is not None: + self.tokenizer = Tokenizer.from_str(vocab_str) def preprocess_text(self, txt): txt = english_cleaners(txt) diff --git a/TTS/tts/layers/xtts/diffusion.py b/TTS/tts/layers/xtts/diffusion.py new file mode 100644 index 00000000..a0b93add --- /dev/null +++ b/TTS/tts/layers/xtts/diffusion.py @@ -0,0 +1,1344 @@ +import enum +import math + +import numpy as np +import torch +import torch as th +from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral +from tqdm import tqdm + +from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper + +K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +SAMPLERS = ["dpm++2m", "p", "ddim"] + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. 
+ :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = "previous_x" # the model predicts x_{t-1} + START_X = "start_x" # the model predicts x_0 + EPSILON = "epsilon" # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
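For concreteness, a small sketch of the `LEARNED_RANGE` convention described above: the network emits a value in [-1, 1] per element, which is mapped onto a log-variance between the clipped posterior variance (lower bound) and beta_t (upper bound). This mirrors the interpolation performed later in `GaussianDiffusion.p_mean_variance`; the helper name and toy numbers are mine.

```python
import numpy as np

def interp_log_variance(v, posterior_log_var_t, log_beta_t):
    """v == -1 selects the posterior variance, v == +1 selects beta_t."""
    frac = (v + 1.0) / 2.0
    return frac * log_beta_t + (1.0 - frac) * posterior_log_var_t

# Toy numbers, purely illustrative.
print(np.exp(interp_log_variance(-1.0, np.log(1e-4), np.log(2e-2))))  # ~1e-4
print(np.exp(interp_log_variance(+1.0, np.log(1e-4), np.log(2e-2))))  # ~2e-2
```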
+ """ + + LEARNED = "learned" + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED_RANGE = "learned_range" + + +class LossType(enum.Enum): + MSE = "mse" # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) + KL = "kl" # use the variational lower-bound + RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, # this is generally False + conditioning_free=False, + conditioning_free_k=1, + ramp_conditioning_free=True, + sampler="ddim", + ): + self.sampler = sampler + self.model_mean_type = ModelMeanType(model_mean_type) + self.model_var_type = ModelVarType(model_var_type) + self.loss_type = LossType(loss_type) + self.rescale_timesteps = rescale_timesteps + self.conditioning_free = conditioning_free + self.conditioning_free_k = conditioning_free_k + self.ramp_conditioning_free = ramp_conditioning_free + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). 
+ + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
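As a sanity check on the `'pred_xstart'` value documented here, a minimal sketch of the noise-to-x_0 inversion (the same algebra as `_predict_xstart_from_eps` further down; the function name and toy shapes are mine):

```python
import torch

def predict_x0_from_eps(x_t, eps, alpha_cumprod_t):
    # x_0 = sqrt(1 / a_bar) * x_t - sqrt(1 / a_bar - 1) * eps
    return (1.0 / alpha_cumprod_t) ** 0.5 * x_t - (1.0 / alpha_cumprod_t - 1.0) ** 0.5 * eps

x0 = torch.randn(2, 4)
eps = torch.randn_like(x0)
a_bar = torch.tensor(0.7)
x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps  # forward process q(x_t | x_0)
print(torch.allclose(predict_x0_from_eps(x_t, eps, a_bar), x0, atol=1e-5))  # True
```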
+ """ + if model_kwargs is None: + model_kwargs = {} + + assert self.model_var_type == ModelVarType.LEARNED_RANGE + assert self.model_mean_type == ModelMeanType.EPSILON + assert denoised_fn is None + assert clip_denoised is True + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.conditioning_free: + model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.conditioning_free: + model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + assert False + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + assert False + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + if self.conditioning_free: + if self.ramp_conditioning_free: + assert t.shape[0] == 1 # This should only be used in inference. 
+ cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + else: + cfk = self.conditioning_free_k + model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning + + def process_xstart(x): + if denoised_fn is not None: + assert False + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + assert False + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + assert False + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + assert False + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
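The `cond_fn` hooks above expect the gradient of a conditional log-probability with respect to `x`. A hedged sketch of how such a function might be built from a noisy-input classifier; the classifier itself is hypothetical and not part of this patch.

```python
import torch

def make_cond_fn(classifier, y, scale=1.0):
    """Wrap a classifier(x, t) -> logits into grad_x log p(y | x, t)."""
    def cond_fn(x, t, **kwargs):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_probs = torch.log_softmax(classifier(x_in, t), dim=-1)
            selected = log_probs[torch.arange(len(y)), y]
            return torch.autograd.grad(selected.sum(), x_in)[0] * scale
    return cond_fn
```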
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def k_diffusion_sample_loop( + self, + k_sampler, + pbar, + model, + shape, + noise=None, # all given + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + device=None, # ALL UNUSED + model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=False, # unused as well + ): + assert isinstance(model_kwargs, dict) + if device is None: + device = next(model.parameters()).device + s_in = noise.new_ones([noise.shape[0]]) + + def model_split(*args, **kwargs): + model_output = model(*args, **kwargs) + model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) + return model_epsilon, model_var + + # + """ + print(self.betas) + print(th.tensor(self.betas)) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) + """ + noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) + + def model_fn_prewrap(x, t, *args, **kwargs): + """ + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + print(t) + print(self.timestep_map) + exit() + """ + """ + model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs) + out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs) + return out['pred_xstart'] + """ + x, _ = x.chunk(2) + t, _ = (t * 1000).chunk(2) + res = torch.cat( + [ + model_split(x, t, conditioning_free=True, **model_kwargs)[0], + model_split(x, t, **model_kwargs)[0], + ] + ) + pbar.update(1) + return res + + model_fn = model_wrapper( + model_fn_prewrap, + noise_schedule, + 
model_type="noise", # "noise" or "x_start" or "v" or "score" + model_kwargs=model_kwargs, + guidance_type="classifier-free", + condition=th.Tensor(1), + unconditional_condition=th.Tensor(1), + guidance_scale=self.conditioning_free_k, + ) + """ + model_fn = model_wrapper( + model_fn_prewrap, + noise_schedule, + model_type='x_start', + model_kwargs={} + ) + # + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + x_sample = dpm_solver.sample( + noise, + steps=20, + order=3, + skip_type="time_uniform", + method="singlestep", + ) + """ + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + x_sample = dpm_solver.sample( + noise, + steps=self.num_timesteps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + #''' + return x_sample + + # HF DIFFUSION ATTEMPT + """ + from .hf_diffusion import EulerAncestralDiscreteScheduler + Scheduler = EulerAncestralDiscreteScheduler() + Scheduler.set_timesteps(100) + for timestep in Scheduler.timesteps: + noise_input = Scheduler.scale_model_input(noise, timestep) + ts = s_in * timestep + model_output = model(noise_input, ts, **model_kwargs) + model_epsilon, _model_var = th.split(model_output, model_output.shape[1]//2, dim=1) + noise, _x0 = Scheduler.step(model_epsilon, timestep, noise) + return noise + """ + + # KARRAS DIFFUSION ATTEMPT + """ + TRAINED_DIFFUSION_STEPS = 4000 # HARDCODED + ratio = TRAINED_DIFFUSION_STEPS/14.5 + def call_model(*args, **kwargs): + model_output = model(*args, **kwargs) + model_output, model_var_values = th.split(model_output, model_output.shape[1]//2, dim=1) + return model_output + print(get_sigmas_karras(self.num_timesteps, sigma_min=0.0, sigma_max=4000, device=device)) + exit() + sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) + return k_sampler(call_model, noise, sigmas, extra_args=model_kwargs, disable=not progress) + ''' + sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) + step = 0 # LMAO + global_sigmas = None + # + def fakemodel(x, t, **model_kwargs): + print(t,global_sigmas*ratio) + return model(x, t, **model_kwargs) + def denoised(x, sigmas, **extra_args): + t = th.tensor([self.num_timesteps-step-1] * shape[0], device=device) + nonlocal global_sigmas + global_sigmas = sigmas + with th.no_grad(): + out = self.p_sample( + fakemodel, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + return out["sample"] + def callback(d): + nonlocal step + step += 1 + + return k_sampler(denoised, noise, sigmas, extra_args=model_kwargs, callback=callback, disable=not progress) + ''' + """ + + def sample_loop(self, *args, **kwargs): + s = self.sampler + if s == "p": + return self.p_sample_loop(*args, **kwargs) + elif s == "ddim": + return self.ddim_sample_loop(*args, **kwargs) + elif s == "dpm++2m": + if self.conditioning_free is not True: + raise RuntimeError("cond_free must be true") + with tqdm(total=self.num_timesteps) as pbar: + return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) + else: + raise RuntimeError("sampler not impl") + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). 
+ :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + for i in tqdm(indices, disable=not progress): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. 
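For reference, the single-step update implemented in `ddim_sample` above (Song et al., 2020, Eq. 12) as a standalone sketch; `eta=0` gives the deterministic path, larger values reintroduce noise. Argument names and broadcasting assumptions are mine.

```python
import torch

def ddim_step(x_t, pred_x0, eps, a_bar_t, a_bar_prev, eta=0.0):
    sigma = eta * ((1 - a_bar_prev) / (1 - a_bar_t)).sqrt() * (1 - a_bar_t / a_bar_prev).sqrt()
    mean = a_bar_prev.sqrt() * pred_x0 + (1 - a_bar_prev - sigma ** 2).sqrt() * eps
    # The sampling loop above additionally masks out the noise term at t == 0.
    return mean + sigma * torch.randn_like(x_t)
```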
+ """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, disable=not progress) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. 
+ """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + # TODO: support multiple model outputs for this mode. + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) + if isinstance(model_outputs, tuple): + model_output = model_outputs[0] + terms["extra_outputs"] = model_outputs[1:] + else: + model_output = model_outputs + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. 
+ elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def autoregressive_training_losses( + self, + model, + x_start, + t, + model_output_keys, + gd_out_key, + model_kwargs=None, + noise=None, + ): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert False # not currently supported for this type of diffusion. + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) + terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = ( + model_output[:, :, 0], + model_output[:, :, 1], + ) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. 
+ elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. 
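Putting the pieces above together, a hedged end-to-end training sketch: build a schedule with `get_named_beta_schedule`, wrap it in `GaussianDiffusion`, and compute `training_losses` for a toy epsilon-prediction model. The import path assumes the file added by this patch; the shapes and the zero-output model are purely illustrative.

```python
import torch

from TTS.tts.layers.xtts.diffusion import GaussianDiffusion, get_named_beta_schedule

betas = get_named_beta_schedule("cosine", num_diffusion_timesteps=1000)
diffusion = GaussianDiffusion(
    betas=betas,
    model_mean_type="epsilon",     # the model predicts the added noise
    model_var_type="fixed_small",  # no learned variance -> pure MSE loss
    loss_type="mse",
)

def toy_eps_model(x, t):
    return torch.zeros_like(x)     # stand-in for the real denoiser

x_start = torch.randn(4, 80, 32)   # toy [N, C, T] batch
t = torch.randint(0, diffusion.num_timesteps, (4,))
losses = diffusion.training_losses(toy_eps_model, x_start, t)
print(losses["loss"].shape)        # per-example loss, torch.Size([4])
```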
+ """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model, autoregressive=False): + if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): + return model + mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel + return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. 
As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) + + +class _WrappedAutoregressiveModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, x0, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, x0, new_ts, **kwargs) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py new file mode 100644 index 00000000..bdd7a9d0 --- /dev/null +++ b/TTS/tts/layers/xtts/dvae.py @@ -0,0 +1,393 @@ +import functools +from math import sqrt + +import torch +import torch.distributed as distributed +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange + + +def default(val, d): + return val if val is not None else d + + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +def dvae_wav_to_mel( + wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") +): + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=1024, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + self.balancing_heuristic = balancing_heuristic + self.codes = None + self.max_codes = 64000 + self.codes_full = False + self.new_return_order = new_return_order + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input, return_soft_codes=False): + if self.balancing_heuristic and self.codes_full: + h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes) + mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1) + ep = self.embed.permute(1, 0) + ea = self.embed_avg.permute(1, 0) + rand_embed = torch.randn_like(ep) * mask + self.embed = (ep * ~mask + rand_embed).permute(1, 0) + self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) + self.cluster_size = self.cluster_size * ~mask.squeeze() + if torch.any(mask): + print(f"Reset {torch.sum(mask)} embedding codes.") + self.codes = None + self.codes_full = False + + flatten = input.reshape(-1, self.dim) + dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True) + soft_codes = -dist + _, embed_ind = soft_codes.max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.balancing_heuristic: + if self.codes is None: + self.codes = embed_ind.flatten() + else: + self.codes = torch.cat([self.codes, embed_ind.flatten()]) + if len(self.codes) > self.max_codes: + self.codes = self.codes[-self.max_codes :] + self.codes_full = True + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + 
distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + if return_soft_codes: + return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) + elif self.new_return_order: + return quantize, embed_ind, diff + else: + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +# Fits a soft-discretized input to a normal-PDF across the specified dimension. +# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete +# values with the specified expected variance. +class DiscretizationLoss(nn.Module): + def __init__(self, discrete_bins, dim, expected_variance, store_past=0): + super().__init__() + self.discrete_bins = discrete_bins + self.dim = dim + self.dist = torch.distributions.Normal(0, scale=expected_variance) + if store_past > 0: + self.record_past = True + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins)) + else: + self.record_past = False + + def forward(self, x): + other_dims = set(range(len(x.shape))) - set([self.dim]) + averaged = x.sum(dim=tuple(other_dims)) / x.sum() + averaged = averaged - averaged.mean() + + if self.record_past: + acc_count = self.accumulator.shape[0] + avg = averaged.detach().clone() + if self.accumulator_filled > 0: + averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count + + # Also push averaged into the accumulator. 
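+            # The buffer acts as a fixed-size ring: the write index wraps back to 0 once
+            # `store_past` entries are filled, so `averaged` is only smoothed over recent batches.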
+ self.accumulator[self.accumulator_index] = avg + self.accumulator_index += 1 + if self.accumulator_index >= acc_count: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + + return torch.sum(-self.dist.log_prob(averaged)) + + +class ResBlock(nn.Module): + def __init__(self, chan, conv, activation): + super().__init__() + self.net = nn.Sequential( + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 1), + ) + + def forward(self, x): + return self.net(x) + x + + +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert "stride" in kwargs.keys() + self.stride = kwargs["stride"] + del kwargs["stride"] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest") + return self.conv(up) + + +# DiscreteVAE partially derived from lucidrains DALLE implementation +# Credit: https://github.com/lucidrains/DALLE-pytorch +class DiscreteVAE(nn.Module): + def __init__( + self, + positional_dims=2, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=4, + use_transposed_convs=True, + encoder_norm=False, + activation="relu", + smooth_l1_loss=False, + straight_through=False, + normalization=None, # ((0.5,) * 3, (0.5,) * 3), + record_codes=False, + discretization_loss_averaging_steps=100, + lr_quantizer_args={}, + ): + super().__init__() + has_resblocks = num_resnet_blocks > 0 + + self.num_tokens = num_tokens + self.num_layers = num_layers + self.straight_through = straight_through + self.positional_dims = positional_dims + self.discrete_loss = DiscretizationLoss( + num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps + ) + + assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. 
+ if positional_dims == 2: + conv = nn.Conv2d + conv_transpose = nn.ConvTranspose2d + else: + conv = nn.Conv1d + conv_transpose = nn.ConvTranspose1d + if not use_transposed_convs: + conv_transpose = functools.partial(UpsampledConv, conv) + + if activation == "relu": + act = nn.ReLU + elif activation == "silu": + act = nn.SiLU + else: + assert NotImplementedError() + + enc_layers = [] + dec_layers = [] + + if num_layers > 0: + enc_chans = [hidden_dim * 2**i for i in range(num_layers)] + dec_chans = list(reversed(enc_chans)) + + enc_chans = [channels, *enc_chans] + + dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] + dec_chans = [dec_init_chan, *dec_chans] + + enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans)) + + pad = (kernel_size - 1) // 2 + for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) + if encoder_norm: + enc_layers.append(nn.GroupNorm(8, enc_out)) + dec_layers.append( + nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()) + ) + dec_out_chans = dec_chans[-1] + innermost_dim = dec_chans[0] + else: + enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act())) + dec_out_chans = hidden_dim + innermost_dim = hidden_dim + + for _ in range(num_resnet_blocks): + dec_layers.insert(0, ResBlock(innermost_dim, conv, act)) + enc_layers.append(ResBlock(innermost_dim, conv, act)) + + if num_resnet_blocks > 0: + dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) + + enc_layers.append(conv(innermost_dim, codebook_dim, 1)) + dec_layers.append(conv(dec_out_chans, channels, 1)) + + self.encoder = nn.Sequential(*enc_layers) + self.decoder = nn.Sequential(*dec_layers) + + self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + + # take care of normalization within class + self.normalization = normalization + self.record_codes = record_codes + if record_codes: + self.codes = torch.zeros((1228800,), dtype=torch.long) + self.code_ind = 0 + self.total_codes = 0 + self.internal_step = 0 + + def norm(self, images): + if not self.normalization is not None: + return images + + means, stds = map(lambda t: torch.as_tensor(t).to(images), self.normalization) + arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()" + means, stds = map(lambda t: rearrange(t, arrange), (means, stds)) + images = images.clone() + images.sub_(means).div_(stds) + return images + + def get_debug_values(self, step, __): + if self.record_codes and self.total_codes > 0: + # Report annealing schedule + return {"histogram_codes": self.codes[: self.total_codes]} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + img = self.norm(images) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, _ = self.codebook(logits) + self.log_codes(codes) + return codes + + def decode(self, img_seq): + self.log_codes(img_seq) + if hasattr(self.codebook, "embed_code"): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.codebook) + b, n, d = image_embeds.shape + + kwargs = {} + if self.positional_dims == 1: + arrange = "b n d -> b d n" + else: + h = w = int(sqrt(n)) + arrange = "b (h w) d -> b d h w" + kwargs = {"h": h, "w": w} + image_embeds = 
rearrange(image_embeds, arrange, **kwargs) + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] + + def infer(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + return self.decode(codes) + + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). + def forward(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) + + if self.training: + out = sampled + for d in self.decoder: + out = d(out) + self.log_codes(codes) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out, _ = self.decode(codes) + + # reconstruction loss + recon_loss = self.loss_fn(img, out, reduction="none") + + return recon_loss, commitment_loss, out + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.record_codes and self.internal_step % 10 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i : i + l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + self.internal_step += 1 diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py new file mode 100644 index 00000000..2a821a5d --- /dev/null +++ b/TTS/tts/layers/xtts/gpt.py @@ -0,0 +1,545 @@ +# ported from: https://github.com/neonbjb/tortoise-tts + +import functools +import math +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2Config + +from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel +from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02, relative=False): + super().__init__() + # nn.Embedding + self.emb = torch.nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + self.relative = relative + self.seq_len = seq_len + + def forward(self, x): + sl = x.shape[1] + if self.relative: + start = random.randint(sl, self.seq_len) - sl + return self.emb(torch.arange(start, start + sl, device=x.device)) + else: + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def build_hf_gpt_transformer( + layers, + model_dim, + heads, + max_mel_seq_len, + max_text_seq_len, + max_prompt_len, + checkpointing, +): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. 
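+        # The context window has to fit the conditioning prompt, the text tokens and the
+        # mel codes laid out back to back, hence the summed sequence lengths below.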
+ n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) + return gpt, mel_pos_emb, text_pos_emb, None, None + + +class GPT(nn.Module): + def __init__( + self, + start_text_token=261, + stop_text_token=0, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_prompt_tokens=70, + max_conditioning_inputs=1, + code_stride_len=1024, + number_text_tokens=256, + num_audio_tokens=8194, + start_audio_token=8192, + stop_audio_token=8193, + train_solo_embeddings=False, + checkpointing=False, + average_conditioning_embeddings=False, + label_smoothing=0.0, + ): + """ + Args: + + """ + super().__init__() + + self.label_smoothing = label_smoothing + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.num_audio_tokens = num_audio_tokens + self.start_audio_token = start_audio_token + self.stop_audio_token = stop_audio_token + self.start_prompt_token = start_audio_token + self.stop_prompt_token = stop_audio_token + self.layers = layers + self.heads = heads + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs + self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 + self.max_prompt_tokens = max_prompt_tokens + self.code_stride_len = code_stride_len + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_dropout = nn.Dropout1d(0.1) + self.average_conditioning_embeddings = average_conditioning_embeddings + + self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) + self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + + self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim) + + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens, + self.max_text_tokens, + self.max_prompt_tokens, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens) + self.mel_head = nn.Linear(model_dim, self.num_audio_tokens) + + def 
get_grad_norm_parameter_groups(self): + return { + "conditioning_encoder": list(self.conditioning_encoder.parameters()), + "gpt": list(self.gpt.parameters()), + "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), + } + + def init_gpt_for_inference(self, kv_cache=True): + seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.gpt_inference = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + self.gpt.wte = self.mel_embedding + + def set_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, code_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with stop_audio_token in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + for b in range(len(code_lengths)): + actual_end = code_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_audio_token + return mel_input_tokens + + def get_logits( + self, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + prompt=None, + get_attns=False, + return_latent=False, + attn_mask_text=None, + attn_mask_mel=None, + ): + if prompt is not None: + offset = prompt.shape[1] + if second_inputs is not None: + emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([prompt, first_inputs], dim=1) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + attn_mask = None + if attn_mask_text is not None: + attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) + if prompt is not None: + attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) + attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1) + + gpt_out = self.gpt( + inputs_embeds=emb, + return_dict=True, + output_attentions=get_attns, + attention_mask=attn_mask, + ) + + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = 
torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def get_prompts(self, prompt_codes): + """ + Create a prompt from the mel codes. This is used to condition the model on the mel codes. + Pad the prompt with start and stop mel tokens. + """ + prompt = prompt_codes + if self.training: + lengths = [] + # Compute the real prompt length based on the first encounter with the token 83 used for padding + for i in range(prompt_codes.shape[0]): + length = 0 + for j in range(prompt_codes.shape[1]): + if prompt_codes[i, j] == 83: + break + else: + length += 1 + lengths.append(length) + + # prompt_len = random.randint(1, 9) # in secs + prompt_len = 3 + prompt_len = prompt_len * 24 # in frames + if prompt_codes.shape[-1] >= prompt_len: + new_prompt = [] + for i in range(prompt_codes.shape[0]): + if lengths[i] < prompt_len: + start = 0 + else: + start = random.randint(0, lengths[i] - prompt_len) + prompt = prompt_codes[:, start : start + prompt_len] + + # add start and stop tokens + prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) + prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) + return prompt + + def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True): + """ + cond_input: (b, 80, s) or (b, 1, 80, s) + conds: (b, 1024, s) + """ + conds = None + if not return_latent: + if cond_input.ndim == 4: + cond_input = cond_input.squeeze(1) + if sample: + _len_secs = random.randint(2, 6) # in secs + cond_seg_len = int((22050 / 1024) * _len_secs) # in frames + if cond_input.shape[-1] >= cond_seg_len: + new_conds = [] + for i in range(cond_input.shape[0]): + cond_len = int(cond_lens[i] / 1024) + if cond_len < cond_seg_len: + start = 0 + else: + start = random.randint(0, cond_len - cond_seg_len) + cond_vec = cond_input[i, :, start : start + cond_seg_len] + new_conds.append(cond_vec) + conds = torch.stack(new_conds, dim=0) + else: + cond_seg_len = 5 if cond_seg_len is None else cond_seg_len # secs + cond_frame_len = int((22050 / 1024) * cond_seg_len) + conds = cond_input[:, :, -cond_frame_len:] + + conds = self.conditioning_encoder(conds) + else: + # already computed + conds = cond_input.unsqueeze(1) + return conds + + def forward( + self, + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + cond_lens=None, + cond_mels=None, + cond_latents=None, + loss_weights=None, + return_attentions=False, + return_latent=False, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + cond_mels: MEL float tensor, (b, 1, 80,s) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + """ + # ❗ FIXIT + if self.max_conditioning_inputs == 0: + assert cond_mels is None, " ❗ cond_mels is not None, but max_conditioning_inputs == 0" + + max_text_len = text_lengths.max() + code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3 + + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. 
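+        # Worked example (illustrative): a 1 s clip at 22050 Hz with the default
+        # code_stride_len=1024 yields ceil(22050 / 1024) + 3 = 25 codes, and max_mel_len
+        # below is the longest such (still silence-padded) code sequence in the batch.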
+ max_mel_len = code_lengths.max() + + if max_mel_len > audio_codes.shape[-1]: + audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1])) + + silence = True + for idx, l in enumerate(code_lengths): + length = l.item() + while silence: + if audio_codes[idx, length - 1] != 83: + break + length -= 1 + code_lengths[idx] = length + + # 💖 Lovely assertions + assert ( + max_mel_len <= audio_codes.shape[-1] + ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})" + assert ( + max_text_len <= text_inputs.shape[-1] + ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" + + # Append stop token to text inputs + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + # Append silence token to mel codes + audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) + + # Pad mel codes with stop_audio_token + audio_codes = self.set_mel_padding(audio_codes, code_lengths) + + # Build input and target tensors + # Prepend start token to inputs and append stop token to targets + text_inputs, text_targets = self.set_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + audio_codes, mel_targets = self.set_inputs_and_targets( + audio_codes, self.start_audio_token, self.stop_audio_token + ) + + # Set attn_mask + attn_mask_text = None + attn_mask_mel = None + if not return_latent: + attn_mask_text = torch.ones( + text_inputs.shape[0], + text_inputs.shape[1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_mel = torch.ones( + audio_codes.shape[0], + audio_codes.shape[1], + dtype=torch.bool, + device=audio_codes.device, + ) + + for idx, l in enumerate(text_lengths): + attn_mask_text[idx, l + 1 :] = 0.0 + + for idx, l in enumerate(code_lengths): + attn_mask_mel[idx, l + 1 :] = 0.0 + + # Compute text embeddings + positional embeddings + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + # Compute mel embeddings + positional embeddings + mel_emb = self.mel_embedding(audio_codes) + self.mel_pos_embedding(audio_codes) + + # Compute speech conditioning input + if cond_latents is None: + cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2) + + # Get logits + sub = -5 # don't ask me why 😄 + if self.training: + sub = -1 + + text_logits, mel_logits = self.get_logits( + text_emb, + self.text_head, + mel_emb, + self.mel_head, + prompt=cond_latents, + get_attns=return_attentions, + return_latent=return_latent, + attn_mask_text=attn_mask_text, + attn_mask_mel=attn_mask_mel, + ) + if return_latent: + return mel_logits[:, :sub] # sub to prevent bla. + + if return_attentions: + return mel_logits + + # Set paddings to -1 to ignore them in loss + for idx, l in enumerate(text_lengths): + text_targets[idx, l + 1 :] = -1 + + for idx, l in enumerate(code_lengths): + mel_targets[idx, l + 1 :] = -1 + + # check if stoptoken is in every row of mel_targets + assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[ + 0 + ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row." 
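+        # text_logits / mel_logits come back as (batch, vocab, time), which is the layout
+        # F.cross_entropy expects; the -1 targets set above are skipped via ignore_index.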
+ + # Compute losses + loss_text = F.cross_entropy( + text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + loss_mel = F.cross_entropy( + mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference(self, cond_latents, text_inputs, **hf_generate_kwargs): + self.compute_embeddings(cond_latents, text_inputs) + return self.generate(cond_latents, text_inputs, input_tokens=None, **hf_generate_kwargs) + + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + emb = torch.cat([cond_latents, emb], dim=1) + self.gpt_inference.store_prefix_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_audio_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_audio_token + return gpt_inputs + + def generate( + self, + cond_latents, + text_inputs, + **hf_generate_kwargs, + ): + gpt_inputs = self.compute_embeddings(cond_latents, text_inputs) + gen = self.gpt_inference.generate( + gpt_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + **hf_generate_kwargs, + ) + if "return_dict_in_generate" in hf_generate_kwargs: + return gen.sequences[:, gpt_inputs.shape[1] :], gen + return gen[:, gpt_inputs.shape[1] :] diff --git a/TTS/tts/layers/xtts/gpt_encoder_eren.py b/TTS/tts/layers/xtts/gpt_encoder_eren.py new file mode 100644 index 00000000..b5e7158d --- /dev/null +++ b/TTS/tts/layers/xtts/gpt_encoder_eren.py @@ -0,0 +1,658 @@ +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + """Override GPT2LMHeadModel to allow for prefix conditioning.""" + + def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.pos_embedding = pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_prefix_emb(self, prefix_emb): + self.cached_prefix_emb = prefix_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + 
position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_prefix_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] + + # Create embedding + prefix_len = self.cached_prefix_emb.shape[1] + if input_ids.shape[1] != 1: + gen_inputs = input_ids[:, prefix_len:] + gen_emb = self.embeddings(gen_inputs) + gen_emb = gen_emb + self.pos_embedding(gen_emb) + if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: + prefix_emb = self.cached_prefix_emb.repeat_interleave( + gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 + ) + else: + prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) + emb = torch.cat([prefix_emb, gen_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - (prefix_len + 1), attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_channels, init_std=0.02, relative=False): + super().__init__() + self.emb = nn.Embedding(seq_len, model_channels) + nn.init.normal_(self.emb.weight, mean=0.0, std=init_std) + self.relative = relative + + def forward(self, x): + seq_len = x.shape[1] + if self.relative: + start = torch.randint(seq_len, (1,), device=x.device).item() + positions = torch.arange(start, start + seq_len, device=x.device) + else: + positions = torch.arange(seq_len, device=x.device) + return self.emb(positions) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing): + """ + Initializes a GPT-2 model and its position embeddings for a text-to-speech system. + + Args: + layers (int): Number of layers in the GPT-2 model. + model_channels (int): Dimension of the GPT-2 model. + heads (int): Number of heads in the GPT-2 model. + max_mel_seq_len (int): Maximum sequence length for the mel spectrogram. + max_text_seq_len (int): Maximum sequence length for the text. + max_prompt_len (int): Maximum length of the prompt. + checkpointing (bool): Whether to use gradient checkpointing. + + Returns: + gpt (GPT2Model): GPT-2 model. + mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram. + text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text. + """ + gpt_config = GPT2Config( + vocab_size=123, + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_embd=model_channels, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + + del gpt.wpe + del gpt.wte + + gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels) + + audio_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_channels) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_channels) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_channels) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_channels) + ) + + return gpt, audio_pos_emb, text_pos_emb + + +class XTTSGPTEncoder(nn.Module): + """XTTS GPT Encoder model implementation. + Args: + start_text_token (int): Index of the start token in the text vocabulary. + stop_text_token (int): Index of the stop token in the text vocabulary. + n_layers (int): Number of layers in the GPT-2 model. + n_model_channels (int): Dimension of the GPT-2 model. + n_heads (int): Number of heads in the GPT-2 model. + max_text_tokens (int): Maximum number of text tokens. + max_audio_tokens (int): Maximum number of audio tokens. + max_prompt_tokens (int): Maximum number of prompt tokens. + audio_len_compression (int): Compression factor for the audio length. + number_text_tokens (int): Number of text tokens. + number_audio_codes (int): Number of audio codes. + start_mel_token (int): Index of the start token in the mel code vocabulary. + stop_mel_token (int): Index of the stop token in the mel code vocabulary. + checkpointing (bool): Whether or not to use gradient checkpointing at training. 
+ """ + + _inference_flag = False + + def __init__( + self, + start_text_token=261, + stop_text_token=0, + n_layers=8, + n_model_channels=512, + n_heads=8, + max_text_tokens=120, + max_audio_tokens=250, + max_prompt_tokens=70, + audio_len_compression=1024, + number_text_tokens=256, + number_audio_codes=8194, + start_mel_token=8192, + stop_mel_token=8193, + checkpointing=True, + label_smoothing=0.0, + ): + super().__init__() + + self.label_smoothing = label_smoothing + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.number_audio_codes = number_audio_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.start_prompt_token = start_mel_token + self.stop_prompt_token = stop_mel_token + self.n_layers = n_layers + self.n_heads = n_heads + self.n_model_channels = n_model_channels + self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs + self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 + self.max_prompt_tokens = max_prompt_tokens + self.audio_len_compression = audio_len_compression + + # embedding layers + self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels) + self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels) + self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels) + self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels) + + # initialize the GPT-2 model + ( + self.gpt, + self.audio_pos_embedding, + self.text_pos_embedding, + ) = init_gpt( + n_layers, + n_model_channels, + n_heads, + self.max_audio_tokens, + self.max_text_tokens, + self.max_prompt_tokens, + checkpointing, + ) + + # output layers + self.final_norm = nn.LayerNorm(n_model_channels) + self.text_head = nn.Linear(n_model_channels, self.number_text_tokens) + self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes) + + def get_grad_norm_parameter_groups(self): + return { + "conditioning_encoder": list(self.conditioning_encoder.parameters()), + "gpt": list(self.gpt.parameters()), + "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), + } + + def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False): + self._inference_flag = True + seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens + gpt_config = GPT2Config( + vocab_size=self.max_audio_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.n_model_channels, + n_layer=self.n_layers, + n_head=self.n_heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.audio_pos_embedding, + self.audio_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + self.gpt.wte = self.audio_embedding + + def set_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_audio_tokens_padding(self, audio_tokens, audio_token_lens): + # Set padding areas within MEL (currently it is coded with the MEL code for ). 
+ for b in range(len(audio_token_lens)): + actual_end = audio_token_lens[b] + if actual_end < audio_tokens.shape[-1]: + audio_tokens[b, actual_end:] = self.stop_mel_token + return audio_tokens + + def get_logits( + self, + speech_conditioning_inputs, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + prompt=None, + get_attns=False, + return_latent=False, + attn_mask_text=None, + attn_mask_mel=None, + ): + if prompt is not None and speech_conditioning_inputs is not None: + offset = speech_conditioning_inputs.shape[1] + prompt.shape[1] + if second_inputs is not None: + emb = torch.cat( + [speech_conditioning_inputs, prompt, first_inputs, second_inputs], + dim=1, + ) + else: + emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1) + elif speech_conditioning_inputs is not None: + offset = speech_conditioning_inputs.shape[1] + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + elif prompt is not None: + offset = prompt.shape[1] + if second_inputs is not None: + emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([prompt, first_inputs], dim=1) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + attn_mask = None + if attn_mask_text is not None: + attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) + if prompt is not None: + attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) + attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1) + + gpt_out = self.gpt( + inputs_embeds=emb, + return_dict=True, + output_attentions=get_attns, + attention_mask=attn_mask, + ) + + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def get_prompts(self, prompt_codes): + prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token) + prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token) + return prompt + + def forward( + self, + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + prompt_codes, + return_attentions=False, + return_latent=False, + ): + max_text_len = text_lengths.max() + + # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values + # Like [..., 186, 45, 45, 83] where actually it should end with 186. + # We take last 3 codes to prevent abrupt ending of the audio. + # TODO: This is might need some testing. 
+ mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3 + + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. + max_mel_len = mel_lengths.max() + + if max_mel_len > audio_codes.shape[-1]: + audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1])) + + # silence aware lengths, skip the silence tokens at the end of the mel codes. + silence = True + for idx, l in enumerate(mel_lengths): + length = l.item() + while silence: + if audio_codes[idx, length - 1] != 83: + break + length -= 1 + mel_lengths[idx] = length + + # Lovely assertions + assert ( + max_mel_len <= audio_codes.shape[-1] + ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})" + assert ( + max_text_len <= text_inputs.shape[-1] + ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" + + # Append stop token to text inputs + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + # Append silence token to mel codes + audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) + + # Pad mel codes with STOP_MEL_TOKEN + audio_codes = self.set_mel_padding(audio_codes, mel_lengths) + + # Compute speech conditioning input + conds = None + if speech_conditioning_input is not None: + if not return_latent: + # Compute speech conditioning input + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) + else: + # already computed + conds = speech_conditioning_input.unsqueeze(1) + + # Build input and target tensors + # Prepend start token to inputs and append stop token to targets + text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token) + audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token) + + # Set attn_mask + attn_mask_text = None + attn_mask_mel = None + if not return_latent: + attn_mask_text = torch.ones( + text_inputs.shape[0], + text_inputs.shape[1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_mel = torch.ones( + audio_codes.shape[0], + audio_codes.shape[1], + dtype=torch.bool, + device=audio_codes.device, + ) + + for idx, l in enumerate(text_lengths): + attn_mask_text[idx, l + 1 :] = 0.0 + + for idx, l in enumerate(mel_lengths): + attn_mask_mel[idx, l + 1 :] = 0.0 + + # Compute text embeddings + positional embeddings + # print(" > text input latent:", text_inputs) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + # Compute mel embeddings + positional embeddings + audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes) + + # Compute prompt embeddings + positional embeddings + prompt = self.get_prompts(prompt_codes) + + # prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach() + prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt) + + # dropout prompt embeddings + prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training) + + # Get logits + sub = -4 # don't ask me why 😄 + if self.training: + sub = -1 + _, audio_logits = 
self.get_logits( + conds, + text_emb, + self.text_head, + audio_emb, + self.mel_head, + prompt=prompt_emb, + get_attns=return_attentions, + return_latent=return_latent, + attn_mask_text=attn_mask_text, + attn_mask_mel=attn_mask_mel, + ) + return audio_logits[:, :sub] # sub to prevent bla. + + def compute_embeddings( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + prompt_codes=None, + pad_input_text=False, + ): + """Compute all the embeddings needed for inference.""" + if pad_input_text and text_inputs.shape[1] < 250: + text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) + else: + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + print(" > Text inputs:", text_inputs) + if prompt_codes is not None: + prompt_codes = self.get_prompts(prompt_codes) + # prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) + prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes) + + print(" > Prompt inputs:", prompt_codes) + print(" > Prompt inputs shape:", prompt_codes.shape) + emb = torch.cat([prompt_emb, emb], dim=1) + + if speech_conditioning_latent is not None: + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + + self.inference_model.store_prefix_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + + if input_tokens is not None: + fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) + return fake_inputs + + def inference( + self, + text_inputs, + input_tokens=None, + prompt_codes=None, + pad_input_text=False, + **hf_generate_kwargs, + ): + if pad_input_text and text_inputs.shape[1] < 250: + text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) + else: + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + if prompt_codes is not None: + prompt_codes = self.get_prompts(prompt_codes) + prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes) + emb = torch.cat([prompt_emb, emb], dim=1) + + self.inference_model.store_prefix_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + + if input_tokens is not None: + fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) + + gen = self.inference_model.generate( + fake_inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + **hf_generate_kwargs, + ) + if "return_dict_in_generate" in hf_generate_kwargs: + return gen.sequences[:, fake_inputs.shape[1] :], gen + return gen[:, fake_inputs.shape[1] :] diff --git a/TTS/tts/layers/xtts/gpt_encoder_old.py b/TTS/tts/layers/xtts/gpt_encoder_old.py new file mode 100644 index 00000000..46739aa2 --- /dev/null +++ 
b/TTS/tts/layers/xtts/gpt_encoder_old.py @@ -0,0 +1,1057 @@ +import functools +import math +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + import deepspeed + from deepspeed.ops.transformer.inference import DeepSpeedTransformerInferenceKernel +except ImportError: + pass + +import dlas.codes.torch_intermediary as ml +from dlas.codes.models.arch_util import AttentionBlock +from dlas.codes.trainer.networks import register_model +from dlas.codes.utils.transformers.stream_generator import init_stream_support +from dlas.codes.utils.util import opt_get +from transformers import GPT2Config, GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +init_stream_support() + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. + """ + + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + """Override GPT2LMHeadModel to allow for prefix conditioning.""" + + def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.pos_embedding = pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_prefix_emb(self, prefix_emb): + self.cached_prefix_emb = prefix_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_prefix_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. 
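+        # On the first call the cached prefix embedding (conditioning + prompt + text) is
+        # prepended to the newly embedded tokens; later calls receive a single token and rely
+        # on the kv cache plus a fixed positional embedding for its absolute position.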
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] + + # Create embedding + prefix_len = self.cached_prefix_emb.shape[1] + if input_ids.shape[1] != 1: + gen_inputs = input_ids[:, prefix_len:] + gen_emb = self.embeddings(gen_inputs) + gen_emb = gen_emb + self.pos_embedding(gen_emb) + if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: + prefix_emb = self.cached_prefix_emb.repeat_interleave( + gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 + ) + else: + prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) + emb = torch.cat([prefix_emb, gen_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - (prefix_len + 1), attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + do_checkpointing=False, + mean=False, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + self.mean = mean + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + if self.mean: + return h.mean(dim=2) + else: + return h[:, :, 0] + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02, relative=False): + super().__init__() + # nn.Embedding + self.emb = torch.nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + self.relative = relative + self.seq_len = seq_len + + def forward(self, x): + sl = x.shape[1] + if self.relative: + start = random.randint(sl, self.seq_len) - sl + return self.emb(torch.arange(start, start + sl, device=x.device)) + else: + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def build_hf_gpt_transformer( + layers, + model_dim, + heads, + max_mel_seq_len, + max_text_seq_len, + max_prompt_len, + checkpointing, +): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. 
+ del gpt.wte + + # def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # attn_output = torch.nn.functional.scaled_dot_product_attention( + # query, key, value, dropout_p=self.attn_dropout.p, is_causal=True + # ) + # return attn_output, None + + # for i in range(len(gpt.h)): + # gpt.h[i].attn._attn = types.MethodType( + # _attn, gpt.h[i].attn + # ) + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) + return gpt, mel_pos_emb, text_pos_emb, None, None + + +class MelEncoder(nn.Module): + def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): + super().__init__() + self.channels = channels + self.encoder = nn.Sequential( + nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + ) + self.reduction = 4 + + def forward(self, x): + for e in self.encoder: + x = e(x) + return x.permute(0, 2, 1) + + +class UnifiedVoice(nn.Module): + def __init__( + self, + start_text_token=261, + stop_text_token=0, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_prompt_tokens=70, + max_conditioning_inputs=1, + mel_length_compression=1024, + number_text_tokens=256, + number_mel_codes=8194, + start_mel_token=8192, + stop_mel_token=8193, + train_solo_embeddings=False, + use_mel_codes_as_input=True, + checkpointing=True, + average_conditioning_embeddings=False, + freeze_everything_but_position_embeddings=False, + freeze_conditioning_encoder=False, + tortoise_compat=True, + label_smoothing=0.0, + ): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model. 
+ """ + super().__init__() + + self.label_smoothing = label_smoothing + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.start_prompt_token = start_mel_token + self.stop_prompt_token = stop_mel_token + self.layers = layers + self.heads = heads + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs + self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 + self.max_prompt_tokens = max_prompt_tokens + self.mel_length_compression = mel_length_compression + # self.conditioning_encoder = ConditioningEncoder( + # 80, model_dim, num_attn_heads=heads + # ) + self.average_conditioning_embeddings = average_conditioning_embeddings + self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b + # nn.Embedding + self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim) + if use_mel_codes_as_input: + # nn.Embedding + self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens, + self.max_text_tokens, + self.max_prompt_tokens, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = ml.Linear(model_dim, self.number_text_tokens) + self.mel_head = ml.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=0.02) + + if freeze_conditioning_encoder: + print(" > Freezing conditioning encoder.") + for p in self.conditioning_encoder.parameters(): + p.requires_grad = False + p.DO_NOT_TRAIN = True + + if freeze_everything_but_position_embeddings: + for p in self.parameters(): + p.requires_grad = False + p.DO_NOT_TRAIN = True + for m in [self.mel_pos_embedding, self.text_pos_embedding]: + for p in m.parameters(): + del p.DO_NOT_TRAIN + p.requires_grad = True + + def get_grad_norm_parameter_groups(self): + return { + "conditioning_encoder": list(self.conditioning_encoder.parameters()), + "gpt": list(self.gpt.parameters()), + "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), + } + + def post_init_gpt2_config(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False): + seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + 
self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + + if use_deepspeed: + # init deepspeed inference engine + if use_deepspeed_f16: + self.gpt.wte = self.mel_embedding.half() + self.gpt.wpe = self.mel_pos_embedding.half() + self.ds_engine = deepspeed.init_inference( + model=self.inference_model.half(), # Transformers models + mp_size=1, # Number of GPU + dtype=torch.float16 if use_deepspeed_f16 else torch.float32, # desired data type of output + replace_method="auto", # Lets DS autmatically identify the layer to replace + replace_with_kernel_inject=True, # replace the model with the kernel injector + ) + self.inference_model = self.ds_engine.module.eval() + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, mel_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + for b in range(len(mel_lengths)): + actual_end = mel_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def get_logits( + self, + speech_conditioning_inputs, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + prompt=None, + get_attns=False, + return_latent=False, + attn_mask_text=None, + attn_mask_mel=None, + ): + if prompt is not None and speech_conditioning_inputs is not None: + offset = speech_conditioning_inputs.shape[1] + prompt.shape[1] + if second_inputs is not None: + emb = torch.cat( + [speech_conditioning_inputs, prompt, first_inputs, second_inputs], + dim=1, + ) + else: + emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1) + elif speech_conditioning_inputs is not None: + offset = speech_conditioning_inputs.shape[1] + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + elif prompt is not None: + offset = prompt.shape[1] + if second_inputs is not None: + emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([prompt, first_inputs], dim=1) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + attn_mask = None + if attn_mask_text is not None: + attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) + if prompt is not None: + attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) + attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1) + + gpt_out = self.gpt( + inputs_embeds=emb, + return_dict=True, + output_attentions=get_attns, + attention_mask=attn_mask, + ) + + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return 
enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def get_prompts(self, prompt_codes): + """ + Create a prompt from the mel codes. This is used to condition the model on the mel codes. + Pad the prompt with start and stop mel tokens. + """ + prompt = prompt_codes + if self.training: + prompt_len = random.randint(1, 9) # in secs + prompt_len = prompt_len * 24 # in frames + + if prompt_codes.shape[1] < prompt_len: + prompt_len = prompt_codes.shape[-1] + start = 0 + else: + start = random.randint(0, prompt_codes.shape[-1] - prompt_len) + + prompt = prompt_codes[:, start : start + prompt_len] + + # add start and stop tokens + prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) + prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) + return prompt + + # def get_prompts(self, prompt_codes): + # """ + # Create a prompt from the mel codes. This is used to condition the model on the mel codes. + # Pad the prompt with start and stop mel tokens. + # """ + # prompt = prompt_codes + # if self.training: + # max_prompt_len = 9 * 24 + # if prompt_codes.shape[1] < max_prompt_len: + # prompt = prompt_codes + # else: + # start = random.randint(0, prompt_codes.shape[1] - max_prompt_len) + # prompt = prompt_codes[:, start : start + max_prompt_len] + + # # add start and stop tokens + # prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) + # prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) + # return prompt + + def forward( + self, + speech_conditioning_input, + text_inputs, + text_lengths, + mel_codes, + wav_lengths, + prompt_codes, + loss_weights=None, + text_first=True, + return_attentions=False, + return_latent=False, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,80,s) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + """ + + # ❗ FIXIT + speech_conditioning_input = None + if self.max_conditioning_inputs == 0: + assert ( + speech_conditioning_input is None + ), " ❗ speech_conditioning_input is not None, but max_conditioning_inputs == 0" + + max_text_len = text_lengths.max() + # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values + # Like [..., 186, 45, 45, 83] where actually it should end with 186. 
+ # We take last 3 codes to prevent abrupt ending of the audio. + # TODO: This is might need some testing. + mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3 + + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. + max_mel_len = mel_lengths.max() + + if max_mel_len > mel_codes.shape[-1]: + mel_codes = F.pad(mel_codes, (0, max_mel_len - mel_codes.shape[-1])) + + # mel_lengths[mel_lengths >= max_mel_len] = max_mel_len + + # silence aware lengths, skip the silence tokens at the end of the mel codes. + silence = True + for idx, l in enumerate(mel_lengths): + length = l.item() + while silence: + if mel_codes[idx, length - 1] != 83: + break + length -= 1 + mel_lengths[idx] = length + + # Lovely assertions + assert ( + max_mel_len <= mel_codes.shape[-1] + ), f" ❗ max_mel_len ({max_mel_len}) > mel_codes.shape[-1] ({mel_codes.shape[-1]})" + assert ( + max_text_len <= text_inputs.shape[-1] + ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" + + # Append stop token to text inputs + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + # Append silence token to mel codes + mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) + + # Pad mel codes with STOP_MEL_TOKEN + mel_codes = self.set_mel_padding(mel_codes, mel_lengths) + + # Compute speech conditioning input + conds = None + if speech_conditioning_input is not None: + if not return_latent: + # Compute speech conditioning input + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) + else: + # already computed + conds = speech_conditioning_input.unsqueeze(1) + + # Build input and target tensors + # Prepend start token to inputs and append stop token to targets + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token + ) + + # Set attn_mask + attn_mask_text = None + attn_mask_mel = None + if not return_latent: + attn_mask_text = torch.ones( + text_inputs.shape[0], + text_inputs.shape[1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_mel = torch.ones( + mel_codes.shape[0], + mel_codes.shape[1], + dtype=torch.bool, + device=mel_codes.device, + ) + + for idx, l in enumerate(text_lengths): + attn_mask_text[idx, l + 1 :] = 0.0 + + for idx, l in enumerate(mel_lengths): + attn_mask_mel[idx, l + 1 :] = 0.0 + + # Compute text embeddings + positional embeddings + # print(" > text input latent:", text_inputs) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + # Compute mel embeddings + positional embeddings + mel_emb = self.mel_embedding(mel_codes) + self.mel_pos_embedding(mel_codes) + + # Compute prompt embeddings + positional embeddings + prompt = self.get_prompts(prompt_codes) + + prompt_emb = self.mel_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach() + + # Get logits + sub = -4 # don't ask me why 😄 + if self.training: + sub = -1 + text_logits, mel_logits = 
self.get_logits( + conds, + text_emb, + self.text_head, + mel_emb, + self.mel_head, + prompt=prompt_emb, + get_attns=return_attentions, + return_latent=return_latent, + attn_mask_text=attn_mask_text, + attn_mask_mel=attn_mask_mel, + ) + if return_latent: + return mel_logits[:, :sub] # sub to prevent bla. + + if return_attentions: + return mel_logits + + # Set paddings to -1 to ignore them in loss + for idx, l in enumerate(text_lengths): + text_targets[idx, l + 1 :] = -1 + + for idx, l in enumerate(mel_lengths): + mel_targets[idx, l + 1 :] = -1 + + # check if stoptoken is in every row of mel_targets + assert (mel_targets == self.stop_mel_token).sum() >= mel_targets.shape[ + 0 + ], f" ❗ mel_targets does not contain stop token ({self.stop_mel_token}) in every row." + + # Compute losses + loss_text = F.cross_entropy( + text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + loss_mel = F.cross_entropy( + mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + + # if loss_weights is not None: + # loss_text = loss_text * loss_weights[:, None] + # loss_mel = loss_mel * loss_weights[:, None] + return loss_text.mean(), loss_mel.mean(), mel_logits + + def text_forward(self, speech_conditioning_input, text_inputs, text_lengths): + """ + Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the + model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided). + """ + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) + + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding + text_logits = self.get_logits(conds, text_emb, self.text_head) + loss_text = F.cross_entropy(text_logits, text_targets.long()) + return loss_text.mean() + + def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None): + """ + Performs autoregressive modeling on only speech data. + """ + assert self.max_mel_tokens >= mel_codes.shape[1], f"{mel_codes.shape[1]}" + + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. 
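+        # For example, with the default mel_length_compression of 1024, a longest clip of
+        # 64000 samples gives max_mel_len = 64000 // 1024 = 62 mel codes.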
+ max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token) + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + if raw_mels is not None: + raw_mels = raw_mels[:, :, : max_mel_len * 4] + + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + if self.average_conditioning_embeddings: + conds = conds.mean(dim=1).unsqueeze(1) + + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token + ) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 4)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding + mel_logits = self.get_logits(conds, mel_emb, self.mel_head) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_mel.mean() + + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.inference_model.generate_stream( + fake_inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + do_stream=True, + **hf_generate_kwargs, + ) + + def compute_embeddings( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + prompt_codes=None, + pad_input_text=False, + ): + if pad_input_text and text_inputs.shape[1] < 250: + text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) + else: + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + print(" > Text inputs:", text_inputs) + if prompt_codes is not None: + prompt_codes = self.get_prompts(prompt_codes) + prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) + print(" > Prompt inputs:", prompt_codes) + print(" > Prompt inputs shape:", prompt_codes.shape) + emb = torch.cat([prompt_emb, emb], dim=1) + + if speech_conditioning_latent is not None: + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + + self.inference_model.store_prefix_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + + if input_tokens is not None: + fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) + return fake_inputs + + def inference_speech( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + prompt_codes=None, + pad_input_text=False, + **hf_generate_kwargs, + ): + if pad_input_text and text_inputs.shape[1] < 250: + text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token) + else: + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + print(" > Text inputs:", text_inputs) 
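+        # The autoregressive prefix is assembled as [speech conditioning | prompt | text];
+        # a start_mel_token is then appended below before generation starts.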
+ if prompt_codes is not None: + prompt_codes = self.get_prompts(prompt_codes) + prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes) + print(" > Prompt inputs:", prompt_codes) + print(" > Prompt inputs shape:", prompt_codes.shape) + emb = torch.cat([prompt_emb, emb], dim=1) + + if speech_conditioning_latent is not None: + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, emb], dim=1) + + self.inference_model.store_prefix_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_mel_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + + if input_tokens is not None: + fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1) + + gen = self.inference_model.generate( + fake_inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens, + **hf_generate_kwargs, + ) + if "return_dict_in_generate" in hf_generate_kwargs: + return gen.sequences[:, fake_inputs.shape[1] :], gen + return gen[:, fake_inputs.shape[1] :] + + # Turns the (utterly insane) output of HF.generate() into a far more sane output: + # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence + def make_hf_generate_attentions_sane(self, attentions): + layers = [[] for _ in range(len(attentions[0]))] + full_attention_size = attentions[-1][0].shape[-1] + for i, gen in enumerate(attentions): + for j, lyr in enumerate(gen): + layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1]))) + catted = [] + for lyr in layers: + catted.append(torch.cat(lyr, dim=2)) + return catted + + def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds): + """ + This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice. 
+ """ + text_padding = num_conds + 2 + num_text = text.shape[-1] + num_context = num_text + text_padding + assert num_context + 1 == attentions[0][0].shape[-1] + attentions = self.make_hf_generate_attentions_sane(attentions) + results = [torch.empty_like(codes) for _ in range(len(attentions))] + for l, layer in enumerate(attentions): + dec_context = layer[:, :, num_context:, :] + # Mask out everything that isn't text (including the start token, which gets a LOT of attention) + dec_context[:, :, :, : text_padding + 1] = 0 + dec_context[:, :, :, num_context:] = 0 + for h in range(dec_context.shape[1]): + dec_context_indices = torch.argmax(dec_context[0, h], dim=-1) + print(f"layer_{l};head_{h}: " + str(dec_context_indices)) + for t, att_tok in enumerate(attentions): + combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device) + for lyr in att_tok: + token_to_text_attentions = lyr[:, :, -1, text_padding : (text_padding + num_text)].sum(dim=1) + combined_attention_weights = combined_attention_weights + token_to_text_attentions + break + most_attended_text_token = combined_attention_weights.argmax(dim=-1) + results[:, t] = most_attended_text_token + eos_token_mask = codes != self.stop_mel_token + return results * eos_token_mask + + +@register_model +def register_unified_voice_prompt(opt_net, opt): + return UnifiedVoice(**opt_get(opt_net, ["kwargs"], {})) + + +if __name__ == "__main__": + gpt = UnifiedVoice( + model_dim=256, + heads=4, + train_solo_embeddings=True, + use_mel_codes_as_input=True, + max_conditioning_inputs=4, + freeze_everything_but_position_embeddings=True, + ) + l = gpt( + torch.randn(2, 3, 80, 800), + torch.randint(high=256, size=(2, 120)), + torch.tensor([32, 120]), + torch.randint(high=8192, size=(2, 250)), + torch.tensor([250 * 256, 195 * 256]), + ) + # gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80])) diff --git a/TTS/tts/layers/xtts/gpt_inference.py b/TTS/tts/layers/xtts/gpt_inference.py new file mode 100644 index 00000000..d44bd3de --- /dev/null +++ b/TTS/tts/layers/xtts/gpt_inference.py @@ -0,0 +1,136 @@ +import math + +import torch +from torch import nn +from transformers import GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +class GPT2InferenceModel(GPT2PreTrainedModel): + """Override GPT2LMHeadModel to allow for prefix conditioning.""" + + def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.pos_embedding = pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_prefix_emb(self, prefix_emb): + self.cached_prefix_emb = prefix_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + 
position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_prefix_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] + + # Create embedding + prefix_len = self.cached_prefix_emb.shape[1] + if input_ids.shape[1] != 1: + gen_inputs = input_ids[:, prefix_len:] + gen_emb = self.embeddings(gen_inputs) + gen_emb = gen_emb + self.pos_embedding(gen_emb) + if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: + prefix_emb = self.cached_prefix_emb.repeat_interleave( + gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 + ) + else: + prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) + emb = torch.cat([prefix_emb, gen_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - (prefix_len + 1), attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py new file mode 100644 index 00000000..f9d62a36 --- /dev/null +++ b/TTS/tts/layers/xtts/latent_encoder.py @@ -0,0 +1,141 @@ +# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts + +import math + +import torch +from torch import nn +from torch.nn import functional as F + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def normalization(channels): + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +class QKVAttention(nn.Module): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, qk_bias=0): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = weight + qk_bias + if mask is not None: + mask = mask.repeat(self.n_heads, 1, 1) + weight[mask.logical_not()] = -torch.inf + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """An attention block that allows spatial positions to attend to each other.""" + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + out_channels=None, + do_activation=False, + ): + super().__init__() + self.channels = channels + out_channels = channels if out_channels is None else out_channels + self.do_activation = do_activation + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, out_channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) + self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) + + def forward(self, x, mask=None, qk_bias=0): + b, c, *spatial = x.shape + if mask is not None: + if len(mask.shape) == 2: + mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) + if mask.shape[1] != x.shape[-1]: + mask = mask[:, : x.shape[-1], : x.shape[-1]] + + x = x.reshape(b, c, -1) + x = self.norm(x) + if self.do_activation: + x = F.silu(x, inplace=True) + qkv = self.qkv(x) + h = self.attention(qkv, 
mask=mask, qk_bias=qk_bias) + h = self.proj_out(h) + xp = self.x_proj(x) + return (xp + h).reshape(b, xp.shape[1], *spatial) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + """ + x: (b, 80, s) + """ + h = self.init(x) + h = self.attn(h) + return h diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py new file mode 100644 index 00000000..0fad8133 --- /dev/null +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -0,0 +1,286 @@ +import json +import os +import re + +import inflect +import pandas as pd +import pypinyin +import torch +from num2words import num2words +from tokenizers import Tokenizer +from unidecode import unidecode + +from TTS.tts.utils.text.cleaners import english_cleaners + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [ + (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + text = text.replace('"', "") + return text + + +def expand_numbers_multilang(text, lang): + # TODO: Handle text more carefully. Currently, it just converts numbers without any context. + # Find all numbers in the input string + numbers = re.findall(r"\d+", text) + + # Transliterate the numbers to text + for num in numbers: + transliterated_num = "".join(num2words(num, lang=lang)) + text = text.replace(num, transliterated_num, 1) + + return text + + +def transliteration_cleaners(text): + """Pipeline for non-English text that transliterates to ASCII.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def multilingual_cleaners(text, lang): + text = lowercase(text) + text = expand_numbers_multilang(text, lang) + text = collapse_whitespace(text) + text = text.replace('"', "") + if lang == "tr": + text = text.replace("İ", "i") + text = text.replace("Ö", "ö") + text = text.replace("Ü", "ü") + return text + + +def english_cleaners(text): + """Pipeline for English text, including number and abbreviation expansion.""" + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + text = text.replace('"', "") + return text + + +def remove_extraneous_punctuation(word): + replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"} + replace = re.compile( + "|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL + ) + word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word) + + # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners. 
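+    # e.g. "[hello]" becomes "(hello)" and "it`s" becomes "it's"; a token that is only a
+    # stray symbol such as "@" or "%" is removed entirely by the regex below.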
+ extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$") + word = extraneous.sub("", word) + return word + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +_whitespace_re = re.compile(r"\s+") + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def arabic_cleaners(text): + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def chinese_cleaners(text): + text = lowercase(text) + text = "".join( + [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)] + ) + return text + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=None, preprocess=None): + self.tokenizer = None + + if vocab_file is not None: + with open(vocab_file, "r", encoding="utf-8") as f: + vocab = json.load(f) + + self.language = vocab["model"]["language"] if "language" in vocab["model"] else None + + if preprocess is None: + self.preprocess = "pre_tokenizer" in vocab and vocab["pre_tokenizer"] + else: + self.preprocess = preprocess + + self.tokenizer = Tokenizer.from_file(vocab_file) + + def preprocess_text(self, txt, lang): + if lang == "ja": + import pykakasi + + kks = pykakasi.kakasi() + results = kks.convert(txt) + txt = " ".join([result["kana"] for result in results]) + txt = basic_cleaners(txt) + elif lang == "en": + txt = english_cleaners(txt) + elif lang == "ar": + txt = arabic_cleaners(txt) + elif lang == "zh-cn": + txt = chinese_cleaners(txt) + else: + txt = multilingual_cleaners(txt, lang) + return txt + + def encode(self, txt, lang): + if self.preprocess: + txt = self.preprocess_text(txt, lang) + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + txt = txt.replace("[UNK]", "") + return txt diff --git a/TTS/tts/layers/xtts/vocoder.py b/TTS/tts/layers/xtts/vocoder.py new file mode 100644 index 00000000..0f4991b8 --- /dev/null +++ b/TTS/tts/layers/xtts/vocoder.py @@ -0,0 +1,385 @@ +import json +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +MAX_WAV_VALUE = 32768.0 + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + 
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.input_conv[0]) + nn.utils.remove_weight_norm(self.kernel_conv) + nn.utils.remove_weight_norm(self.bias_conv) + for block in self.residual_convs: + nn.utils.remove_weight_norm(block[1]) + nn.utils.remove_weight_norm(block[3]) + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm( + nn.Conv1d( + 
in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + nn.utils.remove_weight_norm(self.convt_pre[1]) + for block in self.conv_blocks: + nn.utils.remove_weight_norm(block[1]) + + +class UnivNetGenerator(nn.Module): + """ + UnivNet Generator + + Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py. 
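+
+    Example (shapes follow the self-test at the bottom of this file):
+
+        >>> model = UnivNetGenerator()
+        >>> c = torch.randn(3, 100, 10)  # (batch, n_mel_channels, frames)
+        >>> z = torch.randn(3, 64, 10)   # (batch, noise_dim, frames)
+        >>> model(c, z).shape            # frames are upsampled by prod(strides) == 256
+        torch.Size([3, 1, 2560])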
+ """ + + def __init__( + self, + noise_dim=64, + channel_size=32, + dilations=[1, 3, 9, 27], + strides=[8, 8, 4], + lReLU_slope=0.2, + kpnet_conv_size=3, + # Below are MEL configurations options that this generator requires. + hop_length=256, + n_mel_channels=100, + ): + super(UnivNetGenerator, self).__init__() + self.mel_channel = n_mel_channels + self.noise_dim = noise_dim + self.hop_length = hop_length + channel_size = channel_size + kpnet_conv_size = kpnet_conv_size + + self.res_stack = nn.ModuleList() + hop_length = 1 + for stride in strides: + hop_length = stride * hop_length + self.res_stack.append( + LVCBlock( + channel_size, + n_mel_channels, + stride=stride, + dilations=dilations, + lReLU_slope=lReLU_slope, + cond_hop_length=hop_length, + kpnet_conv_size=kpnet_conv_size, + ) + ) + + self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")) + + self.conv_post = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), + nn.Tanh(), + ) + + def forward(self, c, z): + """ + Args: + c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) + z (Tensor): the noise sequence (batch, noise_dim, in_length) + + """ + z = self.conv_pre(z) # (B, c_g, L) + + for res_block in self.res_stack: + res_block.to(z.device) + z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i) + + z = self.conv_post(z) # (B, 1, L * 256) + + return z + + def eval(self, inference=False): + super(UnivNetGenerator, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv_pre) + + for layer in self.conv_post: + if len(layer.state_dict()) != 0: + nn.utils.remove_weight_norm(layer) + + for res_block in self.res_stack: + res_block.remove_weight_norm() + + def inference(self, c, z=None): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) + mel = torch.cat((c, zero), dim=2) + + if z is None: + z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + + audio = self.forward(mel, z) + audio = audio[:, :, : -(self.hop_length * 10)] + audio = audio.clamp(min=-1, max=1) + return audio + + +if __name__ == "__main__": + model = UnivNetGenerator() + + c = torch.randn(3, 100, 10) + z = torch.randn(3, 64, 10) + print(c.shape) + + y = model(c, z) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py new file mode 100644 index 00000000..0836870e --- /dev/null +++ b/TTS/tts/models/xtts.py @@ -0,0 +1,654 @@ +import os +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit + +from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel +from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts +from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.vocoder import UnivNetGenerator +from 
TTS.tts.models.base_tts import BaseTTS +from TTS.utils.io import load_fsspec + + +def load_audio(audiopath, sr=22050): + """ + Load an audio file from disk and resample it to the specified sampling rate. + + Args: + audiopath (str): Path to the audio file. + sr (int): Target sampling rate. + + Returns: + Tensor: Audio waveform tensor with shape (1, T), where T is the number of samples. + """ + audio, sampling_rate = torchaudio.load(audiopath) + + if len(audio.shape) > 1: + if audio.shape[0] < 5: + audio = audio[0] + else: + assert audio.shape[1] < 5 + audio = audio[:, 0] + + if sampling_rate != sr: + resampler = torchaudio.transforms.Resample(sampling_rate, sr) + audio = resampler(audio) + + audio = audio.clamp_(-1, 1) + return audio.unsqueeze(0) + + +def wav_to_mel_cloning( + wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") +): + """ + Convert waveform to mel-spectrogram with hard-coded parameters for cloning. + + Args: + wav (torch.Tensor): Input waveform tensor. + mel_norms_file (str): Path to mel-spectrogram normalization file. + mel_norms (torch.Tensor): Mel-spectrogram normalization tensor. + device (torch.device): Device to use for computation. + + Returns: + torch.Tensor: Mel-spectrogram tensor. + """ + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +def pad_or_truncate(t, length): + """ + Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. + + Args: + t (torch.Tensor): The input tensor to be padded or truncated. + length (int): The desired length of the tensor. + + Returns: + torch.Tensor: The padded or truncated tensor. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +def load_discrete_vocoder_diffuser( + trained_diffusion_steps=4000, + desired_diffusion_steps=200, + cond_free=True, + cond_free_k=1, + sampler="ddim", +): + """ + Load a GaussianDiffusion instance configured for use as a decoder. + + Args: + trained_diffusion_steps (int): The number of diffusion steps used during training. + desired_diffusion_steps (int): The number of diffusion steps to use during inference. + cond_free (bool): Whether to use a conditioning-free model. + cond_free_k (int): The number of samples to use for conditioning-free models. + sampler (str): The name of the sampler to use. + + Returns: + A SpacedDiffusion instance configured with the given parameters. + """ + return SpacedDiffusion( + use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), + model_mean_type="epsilon", + model_var_type="learned_range", + loss_type="mse", + betas=get_named_beta_schedule("linear", trained_diffusion_steps), + conditioning_free=cond_free, + conditioning_free_k=cond_free_k, + sampler=sampler, + ) + + +def do_spectrogram_diffusion( + diffusion_model, + diffuser, + latents, + conditioning_latents, + temperature=1, +): + """ + Generate a mel-spectrogram using a diffusion model and a diffuser. 
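+    The number of output frames is `latents.shape[1] * 4 * 24000 // 22050`, i.e. the
+    22.05 kHz latent sequence is upsampled 4x and rescaled to the 24 kHz diffusion rate.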
+ + Args: + diffusion_model (nn.Module): A diffusion model that converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + diffuser (Diffuser): A diffuser that generates a mel-spectrogram from noise. + latents (torch.Tensor): A tensor of shape (batch_size, seq_len, code_size) containing the input spectrogram codes. + conditioning_latents (torch.Tensor): A tensor of shape (batch_size, code_size) containing the conditioning codes. + temperature (float, optional): The temperature of the noise used by the diffuser. Defaults to 1. + + Returns: + torch.Tensor: A tensor of shape (batch_size, mel_channels, mel_seq_len) containing the generated mel-spectrogram. + """ + with torch.no_grad(): + output_seq_len = ( + latents.shape[1] * 4 * 24000 // 22050 + ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent( + latents, conditioning_latents, output_seq_len, False + ) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.sample_loop( + diffusion_model, + output_shape, + noise=noise, + model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, + progress=False, + ) + return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] + + +@dataclass +class XttsAudioConfig(Coqpit): + """ + Configuration class for audio-related parameters in the XTTS model. + + Args: + sample_rate (int): The sample rate in which the GPT operates. + diffusion_sample_rate (int): The sample rate of the diffusion audio waveform. + output_sample_rate (int): The sample rate of the output audio waveform. + """ + + sample_rate: int = 22050 + diffusion_sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +@dataclass +class XttsArgs(Coqpit): + """A dataclass to represent XTTS model arguments that define the model structure. + + Args: + gpt_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + + For GPT model: + ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + ar_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. + ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + ar_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + ar_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. 
+ ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + + For DiffTTS model: + diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. + diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. + diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. + diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. + diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. + diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. + diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. + diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False. + diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. + diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. + diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. + """ + + gpt_batch_size: int = 1 + enable_redaction: bool = False + lazy_load: bool = True + kv_cache: bool = True + gpt_checkpoint: str = None + clvp_checkpoint: str = None + decoder_checkpoint: str = None + num_chars: int = 255 + + # XTTS GPT Encoder params + tokenizer_file: str = "" + gpt_max_audio_tokens: int = 605 + gpt_max_text_tokens: int = 402 + gpt_max_prompt_tokens: int = 70 + gpt_layers: int = 30 + gpt_n_model_channels: int = 1024 + gpt_n_heads: int = 16 + gpt_number_text_tokens: int = None + gpt_start_text_token: int = None + gpt_stop_text_token: int = None + gpt_num_audio_tokens: int = 8194 + gpt_start_audio_token: int = 8192 + gpt_stop_audio_token: int = 8193 + + # Diffusion Decoder params + diff_model_channels: int = 1024 + diff_num_layers: int = 10 + diff_in_channels: int = 100 + diff_out_channels: int = 200 + diff_in_latent_channels: int = 1024 + diff_in_tokens: int = 8193 + diff_dropout: int = 0 + diff_use_fp16: bool = False + diff_num_heads: int = 16 + diff_layer_drop: int = 0 + diff_unconditioned_percentage: int = 0 + + # constants + duration_const: int = 102400 + + +class Xtts(BaseTTS): + """ⓍTTS model implementation. + + ❗ Currently it only supports inference. 
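+
+    The model chains a BPE text tokenizer, an autoregressive GPT that maps text tokens to discrete audio tokens, a diffusion
+    decoder that turns the GPT latents into a mel-spectrogram, and a UnivNet vocoder that renders the final waveform.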
+
+    Examples:
+        >>> from TTS.tts.configs.xtts_config import XttsConfig
+        >>> from TTS.tts.models.xtts import Xtts
+        >>> config = XttsConfig()
+        >>> model = Xtts.init_from_config(config)
+        >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
+    """
+
+    def __init__(self, config: Coqpit):
+        super().__init__(config, ap=None, tokenizer=None)
+        self.lazy_load = self.args.lazy_load
+        self.mel_stats_path = None
+        self.config = config
+        self.gpt_checkpoint = self.args.gpt_checkpoint
+        self.decoder_checkpoint = self.args.decoder_checkpoint  # TODO: check if this is even needed
+        self.models_dir = config.model_dir
+        self.gpt_batch_size = self.args.gpt_batch_size
+
+        self.tokenizer = VoiceBpeTokenizer()
+        self.gpt = None
+        self.diffusion_decoder = None
+        self.init_models()
+        self.register_buffer("mel_stats", torch.ones(80))
+
+    def init_models(self):
+        """Initialize the models. We do it here since we need to load the tokenizer first."""
+        if self.tokenizer.tokenizer is not None:
+            self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
+            self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
+            self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
+
+        if self.args.gpt_number_text_tokens:
+            self.gpt = GPT(
+                layers=self.args.gpt_layers,
+                model_dim=self.args.gpt_n_model_channels,
+                start_text_token=self.args.gpt_start_text_token,
+                stop_text_token=self.args.gpt_stop_text_token,
+                heads=self.args.gpt_n_heads,
+                max_text_tokens=self.args.gpt_max_text_tokens,
+                max_mel_tokens=self.args.gpt_max_audio_tokens,
+                max_prompt_tokens=self.args.gpt_max_prompt_tokens,
+                number_text_tokens=self.args.gpt_number_text_tokens,
+                num_audio_tokens=self.args.gpt_num_audio_tokens,
+                start_audio_token=self.args.gpt_start_audio_token,
+                stop_audio_token=self.args.gpt_stop_audio_token,
+            )
+
+        self.diffusion_decoder = DiffusionTts(
+            model_channels=self.args.diff_model_channels,
+            num_layers=self.args.diff_num_layers,
+            in_channels=self.args.diff_in_channels,
+            out_channels=self.args.diff_out_channels,
+            in_latent_channels=self.args.diff_in_latent_channels,
+            in_tokens=self.args.diff_in_tokens,
+            dropout=self.args.diff_dropout,
+            use_fp16=self.args.diff_use_fp16,
+            num_heads=self.args.diff_num_heads,
+            layer_drop=self.args.diff_layer_drop,
+            unconditioned_percentage=self.args.diff_unconditioned_percentage,
+        )
+
+        self.vocoder = UnivNetGenerator()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @contextmanager
+    def lazy_load_model(self, model):
+        """Context to load a model on demand.
+
+        Args:
+            model (nn.Module): The model to be loaded.
+        """
+        if self.lazy_load:
+            yield model
+        else:
+            m = model.to(self.device)
+            yield m
+            m = model.cpu()
+
+    def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
+        """Compute the conditioning latents for the GPT model from the given audio.
+
+        Args:
+            audio_path (str): Path to the audio file.
+            length (int): Length of the audio in seconds. Defaults to 3.
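+
+        Returns:
+            torch.Tensor: GPT conditioning latents produced by the GPT style encoder from a mel-spectrogram of the
+                (cropped) reference audio.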
+ """ + + audio = load_audio(audio_path) + audio = audio[:, : 22050 * length] + mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu()) + cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False) + return cond_latent.transpose(1, 2) + + def get_diffusion_cond_latents( + self, + audio_path, + ): + from math import ceil + + diffusion_conds = [] + CHUNK_SIZE = 102400 + audio = load_audio(audio_path, 24000) + for chunk in range(ceil(audio.shape[1] / CHUNK_SIZE)): + current_sample = audio[:, chunk * CHUNK_SIZE : (chunk + 1) * CHUNK_SIZE] + current_sample = pad_or_truncate(current_sample, CHUNK_SIZE) + cond_mel = wav_to_univnet_mel( + current_sample.to(self.device), + do_normalization=False, + device=self.device, + ) + diffusion_conds.append(cond_mel) + diffusion_conds = torch.stack(diffusion_conds, dim=1) + with self.lazy_load_model(self.diffusion_decoder) as diffusion: + diffusion_latent = diffusion.get_conditioning(diffusion_conds) + return diffusion_latent + + def get_conditioning_latents( + self, + audio_path, + gpt_cond_len=3, + ): + gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T] + diffusion_cond_latents = self.get_diffusion_cond_latents( + audio_path, + ) + return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device) + + def synthesize(self, text, config, speaker_wav, language, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (XttsConfig): Config with inference parameters. + speaker_wav (str): Path to the speaker audio file for cloning. + language (str): Language ID of the speaker. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + + # Make the synthesizer happy 🥳 + if isinstance(speaker_wav, list): + speaker_wav = speaker_wav[0] + + return self.inference_with_config(text, config, ref_audio_path=speaker_wav, language=language, **kwargs) + + def inference_with_config(self, text, config, ref_audio_path, language, **kwargs): + """ + inference with config + """ + assert ( + language in self.config.languages + ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" + # Use generally found best tuning knobs for generation. 
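+        # These presets mirror the corresponding fields of the loaded XttsConfig; any kwargs passed to this call override them.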
+ settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_k": config.top_k, + "top_p": config.top_p, + "cond_free_k": config.cond_free_k, + "diffusion_temperature": config.diffusion_temperature, + "decoder_iterations": config.decoder_iterations, + "decoder_sampler": config.decoder_sampler, + } + settings.update(kwargs) # allow overriding of preset settings with kwargs + return self.inference(text, ref_audio_path, language, **settings) + + @torch.no_grad() + def inference( + self, + text, + ref_audio_path, + language, + # GPT inference + temperature=0.65, + length_penalty=1, + repetition_penalty=2.0, + top_k=50, + top_p=0.85, + gpt_cond_len=4, + do_sample=True, + # Decoder inference + decoder_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + decoder_sampler="ddim", + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + + ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3 + seconds long. + + language: (str) Language of the voice to be generated. + + temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65. + + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the + model to produce more terse outputs. Defaults to 1.0. + + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during + decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. + + top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 50. + + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 0.8. + + gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used + else the first `gpt_cond_len` secs is used. Defaults to 3 seconds. + + decoder_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has + more chances to iteratively refine the output, which should theoretically mean a higher quality output. + Generally a value above 250 is not noticeably better, however. Defaults to 100. + + cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion + performs two forward passes for each diffusion step: one with the outputs of the autoregressive model + and one with no conditioning priors. The output of the two is blended according to the cond_free_k + value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. + Defaults to True. + + cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the + conditioning-present signal. [0,inf]. As cond_free_k increases, the output becomes dominated by the + conditioning-free signal. Defaults to 2.0. + + diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. + Values at 0 re the "mean" prediction of the diffusion network and will sound bland and smeared. + Defaults to 1.0. + + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive + transformer. 
Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + text = f"[{language}]{text.strip().lower()}" + text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + ( + gpt_cond_latent, + diffusion_conditioning, + ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len) + + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=decoder_iterations, + cond_free=cond_free, + cond_free_k=cond_free_k, + sampler=decoder_sampler, + ) + + with torch.no_grad(): + self.gpt = self.gpt.to(self.device) + with self.lazy_load_model(self.gpt) as gpt: + gpt_codes = gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + + with self.lazy_load_model(self.gpt) as gpt: + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + silence_token = 83 + ctokens = 0 + for k in range(gpt_codes.shape[-1]): + if gpt_codes[0, k] == silence_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: + gpt_latents = gpt_latents[:, :k] + break + + with self.lazy_load_model(self.diffusion_decoder) as diffusion: + mel = do_spectrogram_diffusion( + diffusion, + diffuser, + gpt_latents, + diffusion_conditioning, + temperature=diffusion_temperature, + ) + with self.lazy_load_model(self.vocoder) as vocoder: + wav = vocoder.inference(mel) + + return {"wav": wav.cpu().numpy().squeeze()} + + def forward(self): + raise NotImplementedError("XTTS Training is not implemented") + + def eval_step(self): + raise NotImplementedError("XTTS Training is not implemented") + + @staticmethod + def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument + return Xtts(config) + + def eval(self): # pylint: disable=redefined-builtin + """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode.""" + self.gpt.init_gpt_for_inference() + super().eval() + + def load_checkpoint( + self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True + ): + """ + Loads a checkpoint from disk and initializes the model's state and tokenizer. + + Args: + config (dict): The configuration dictionary for the model. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. + checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + vocab_path (str, optional): The path to the vocabulary file. Defaults to None. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False. 
+ strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. + + Returns: + None + """ + + model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") + vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") + + if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")): + self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json")) + + self.init_models() + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.load_state_dict(load_fsspec(model_path)["model"], strict=strict) + + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.gpt.eval() + self.diffusion_decoder.eval() + self.vocoder.eval() + + def train_step(self): + raise NotImplementedError("XTTS Training is not implemented") diff --git a/TTS/tts/utils/text/belarusian/phonemizer.py b/TTS/tts/utils/text/belarusian/phonemizer.py index 3c07a209..1922577e 100644 --- a/TTS/tts/utils/text/belarusian/phonemizer.py +++ b/TTS/tts/utils/text/belarusian/phonemizer.py @@ -8,7 +8,9 @@ def init(): import jpype import jpype.imports except ModuleNotFoundError: - raise ModuleNotFoundError("Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`.") + raise ModuleNotFoundError( + "Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`." + ) try: jar_path = os.environ["BEL_FANETYKA_JAR"] @@ -31,4 +33,5 @@ def belarusian_text_to_phonemes(text: str) -> str: init() from org.alex73.fanetyka.impl import FanetykaText + return str(FanetykaText(finder, text).ipa) diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py index 638184fd..f9a0340c 100644 --- a/TTS/tts/utils/text/phonemizers/__init__.py +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -1,6 +1,6 @@ from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer -from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer diff --git a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py index fb620766..e5fcab6e 100644 --- a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py +++ b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py @@ -1,7 +1,7 @@ from typing import Dict -from TTS.tts.utils.text.phonemizers.base import BasePhonemizer from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer _DEF_BE_PUNCS = ",!." 
 # TODO
diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py
index f8a11ac5..97305762 100644
--- a/TTS/utils/generic_utils.py
+++ b/TTS/utils/generic_utils.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import datetime
 import importlib
+import logging
 import os
 import re
 import subprocess
@@ -219,3 +220,22 @@ class KeepAverage:
     def update_values(self, value_dict):
         for key, value in value_dict.items():
             self.update_value(key, value)
+
+
+def get_timestamp():
+    return datetime.datetime.now().strftime("%y%m%d-%H%M%S")
+
+
+def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
+    lg = logging.getLogger(logger_name)
+    formatter = logging.Formatter("%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S")
+    lg.setLevel(level)
+    if tofile:
+        log_file = os.path.join(root, phase + "_{}.log".format(get_timestamp()))
+        fh = logging.FileHandler(log_file, mode="w")
+        fh.setFormatter(formatter)
+        lg.addHandler(fh)
+    if screen:
+        sh = logging.StreamHandler()
+        sh.setFormatter(formatter)
+        lg.addHandler(sh)
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index 70d35228..ed48758f 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -334,7 +334,7 @@ class ModelManager(object):
         output_model_path = output_path
         output_config_path = None
         if (
-            model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name
+            model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
         ):  # TODO:This is stupid but don't care for now.
             output_model_path, output_config_path = self._find_files(output_path)
         # update paths in the config.json
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index fbae3216..e6f35460 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -242,7 +242,11 @@ class Synthesizer(nn.Module):
             wav (List[int]): waveform as a list of values.
             path (str): output path to save the waveform.
         """
-        wav = np.array(wav)
+        # if the waveform is a tensor, convert it to numpy first
+        if torch.is_tensor(wav):
+            wav = wav.cpu().numpy()
+        if isinstance(wav, list):
+            wav = np.array(wav)
         save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
 
     def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
@@ -334,7 +338,7 @@ class Synthesizer(nn.Module):
 
         elif language_name and isinstance(language_name, str):
             try:
-                language_id = self.tts_model.language_manager.name_to_id[language_name]
+                language_id = self.tts_model.language_manager.name_to_id[language_id]
             except KeyError as e:
                 raise ValueError(
                     f" [!] Looks like you use a multi-lingual model. "
@@ -374,6 +378,8 @@ class Synthesizer(nn.Module):
                 speaker_id=speaker_name,
                 voice_dirs=self.voice_dir,
                 d_vector=speaker_embedding,
+                speaker_wav=speaker_wav,
+                language=language_name,
                 **kwargs,
             )
         else:
diff --git a/docs/source/index.md b/docs/source/index.md
index 5ef3d88c..79993eec 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -53,6 +53,7 @@
     models/overflow.md
     models/tortoise.md
     models/bark.md
+    models/xtts.md
 
 .. toctree::
     :maxdepth: 2
diff --git a/docs/source/models/bark.md b/docs/source/models/bark.md
index 4092d9f4..c328ae61 100644
--- a/docs/source/models/bark.md
+++ b/docs/source/models/bark.md
@@ -1,4 +1,4 @@
-# Bark 🐶
+# 🐶 Bark
 
 Bark is a multi-lingual TTS model created by [Suno-AI](https://www.suno.ai/). It can generate conversational speech as well as music and sound effects. It is architecturally very similar to Google's [AudioLM](https://arxiv.org/abs/2209.03143).
 For more information, please refer to the [Suno-AI's repo](https://github.com/suno-ai/bark).
diff --git a/docs/source/models/tortoise.md b/docs/source/models/tortoise.md
index d602d597..2df6da76 100644
--- a/docs/source/models/tortoise.md
+++ b/docs/source/models/tortoise.md
@@ -1,4 +1,4 @@
-# Tortoise 🐢
+# 🐢 Tortoise
 
 Tortoise is a very expressive TTS system with impressive voice cloning capabilities. It is based on a GPT-like autoregressive acoustic model that converts input text to discretized acoustic tokens, a diffusion model that converts these tokens to mel-spectrogram frames, and a UnivNet vocoder to convert the spectrograms to the final audio signal.
 The important downside is that Tortoise is very slow compared to parallel TTS models like VITS.
diff --git a/docs/source/models/xtts.md b/docs/source/models/xtts.md
new file mode 100644
index 00000000..85a3afba
--- /dev/null
+++ b/docs/source/models/xtts.md
@@ -0,0 +1,108 @@
+# ⓍTTS
+ⓍTTS is a super cool Text-to-Speech model that lets you clone voices into different languages using just a quick 3-second audio clip. Built on 🐢Tortoise,
+ⓍTTS comes with important model changes that make cross-language voice cloning and multi-lingual speech generation super easy.
+There is no need for an excessive amount of training data that spans countless hours.
+
+This is the same model that powers [Coqui Studio](https://coqui.ai/) and [Coqui API](https://docs.coqui.ai/docs); however, we apply
+a few tricks to make it faster and to support streaming inference.
+
+### Features
+- Voice cloning with just a 3-second audio clip.
+- Cross-language voice cloning.
+- Multi-lingual speech generation.
+- 24 kHz sampling rate.
+
+### Code
+The current implementation only supports inference.
+
+### Languages
+As of now, XTTS-v1 supports 13 languages: English, Spanish, French, German, Italian, Portuguese,
+Polish, Turkish, Russian, Dutch, Czech, Arabic, and Chinese (Simplified).
+
+Stay tuned as we continue to add support for more languages. If you have any language requests, please feel free to reach out.
+
+### License
+This model is licensed under the [Coqui Public Model License](https://coqui.ai/cpml).
+
+### Contact
+Come and join our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
+You can also mail us at info@coqui.ai.
+
+Using 🐸TTS API:
+
+```python
+from TTS.api import TTS
+tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
+
+# generate speech by cloning a voice using default settings
+tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                file_path="output.wav",
+                speaker_wav="/path/to/target/speaker.wav",
+                language="en")
+
+# generate speech by cloning a voice using custom settings
+tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
+                file_path="output.wav",
+                speaker_wav="/path/to/target/speaker.wav",
+                language="en",
+                decoder_iterations=30)
+```
+
+Using 🐸TTS Command line:
+
+```console
+ tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
+     --text "Bugün okula gitmek istemiyorum." \
+     --speaker_wav /path/to/target/speaker.wav \
+     --language_idx tr \
+     --use_cuda true
+```
+
+Using the model directly:
+
+```python
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+
+config = XttsConfig()
+config.load_json("/path/to/xtts/config.json")
+model = Xtts.init_from_config(config)
+model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
+model.cuda()
+
+outputs = model.synthesize(
+    "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+    config,
+    speaker_wav="/data/TTS-public/_refclips/3.wav",
+    gpt_cond_len=3,
+    language="en",
+)
+```
+
+
+## Important resources & papers
+- VALL-E: https://arxiv.org/abs/2301.02111
+- Tortoise Repo: https://github.com/neonbjb/tortoise-tts
+- Faster implementation: https://github.com/152334H/tortoise-tts-fast
+- UnivNet: https://arxiv.org/abs/2106.07889
+- Latent Diffusion: https://arxiv.org/abs/2112.10752
+- DALL-E: https://arxiv.org/abs/2102.12092
+
+
+## XttsConfig
+```{eval-rst}
+.. autoclass:: TTS.tts.configs.xtts_config.XttsConfig
+    :members:
+```
+
+## XttsArgs
+```{eval-rst}
+.. autoclass:: TTS.tts.models.xtts.XttsArgs
+    :members:
+```
+
+## XTTS Model
+```{eval-rst}
+.. autoclass:: TTS.tts.models.xtts.Xtts
+    :members:
+```
diff --git a/recipes/bel-alex73/train_glowtts.py b/recipes/bel-alex73/train_glowtts.py
index 24b62d79..74866be7 100644
--- a/recipes/bel-alex73/train_glowtts.py
+++ b/recipes/bel-alex73/train_glowtts.py
@@ -60,7 +60,7 @@ config = GlowTTSConfig(
     output_path=output_path,
     add_blank=True,
     datasets=[dataset_config],
-#    characters=characters,
+    # characters=characters,
    enable_eos_bos_chars=True,
     mixed_precision=False,
     save_step=10000,
diff --git a/tests/text_tests/test_belarusian_phonemizer.py b/tests/text_tests/test_belarusian_phonemizer.py
index 278ee8be..76ba4667 100644
--- a/tests/text_tests/test_belarusian_phonemizer.py
+++ b/tests/text_tests/test_belarusian_phonemizer.py
@@ -1,6 +1,6 @@
 import os
-import warnings
 import unittest
+import warnings
 
 from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes
 
@@ -17,7 +17,8 @@ class TestText(unittest.TestCase):
         except KeyError:
             warnings.warn(
                 "You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file to test Belarusian phonemizer",
-                Warning)
+                Warning,
+            )
             return
 
         for line in _TEST_CASES.strip().split("\n"):