diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py deleted file mode 100644 index cb540577..00000000 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ /dev/null @@ -1,1551 +0,0 @@ -import math - -import torch - - -class NoiseScheduleVP: - def __init__( - self, - schedule="discrete", - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20.0, - dtype=torch.float32, - ): - """Create a wrapper class for the forward SDE (VP type). - - *** - Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. - We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. - *** - - The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). - We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). - Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: - - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - - t = self.inverse_lambda(lambda_t) - - =============================================================== - - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - - 1. For discrete-time DPMs: - - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - - Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. - - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - - - 2. For continuous-time DPMs: - - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - - =============================================================== - - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - - Example: - - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - - """ - - if schedule not in ["discrete", "linear", "cosine"]: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule - ) - ) - - self.schedule = schedule - if schedule == "discrete": - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1.0 - self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) - self.log_alpha_array = log_alphas.reshape( - ( - 1, - -1, - ) - ).to(dtype=dtype) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999.0 - self.cosine_t_max = ( - math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) - * 2.0 - * (1.0 + self.cosine_s) - / math.pi - - self.cosine_s - ) - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) - self.schedule = schedule - if schedule == "cosine": - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1.0 - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == "discrete": - return interpolate_fn( - t.reshape((-1, 1)), - self.t_array.to(t.device), - self.log_alpha_array.to(t.device), - ).reshape((-1)) - elif self.schedule == "linear": - return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == "cosine": - - def log_alpha_fn(s): - return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) - - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == "linear": - tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0**2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == "discrete": - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) - t = interpolate_fn( - log_alpha.reshape((-1, 1)), - torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1]), - ) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) - - def t_fn(log_alpha_t): - return ( - torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) - * 2.0 - * (1.0 + self.cosine_s) - / math.pi - - self.cosine_s - ) - - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1.0, - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - - We support four types of the diffusion model by setting `model_type`: - - 1. "noise": noise prediction model. (Trained by predicting noise). - - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - - =============================================================== - - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == "discrete": - return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - return (x - alpha_t * output) / sigma_t - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - return alpha_t * output + sigma_t * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - return -sigma_t * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * sigma_t * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1.0 or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v", "score"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__( - self, - model_fn, - noise_schedule, - algorithm_type="dpmsolver++", - correcting_x0_fn=None, - correcting_xt_fn=None, - thresholding_max_val=1.0, - dynamic_thresholding_ratio=0.995, - ): - """Construct a DPM-Solver. - - We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). - - We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you - can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the - dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space - DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space - DPMs (such as stable-diffusion). - - To support advanced algorithms in image-to-image applications, we also support corrector functions for - both x0 and xt. - - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". - correcting_x0_fn: A `str` or a function with the following format: - ``` - def correcting_x0_fn(x0, t): - x0_new = ... - return x0_new - ``` - This function is to correct the outputs of the data prediction model at each sampling step. e.g., - ``` - x0_pred = data_pred_model(xt, t) - if correcting_x0_fn is not None: - x0_pred = correcting_x0_fn(x0_pred, t) - xt_1 = update(x0_pred, xt, t) - ``` - If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. - correcting_xt_fn: A function with the following format: - ``` - def correcting_xt_fn(xt, t, step): - x_new = ... - return x_new - ``` - This function is to correct the intermediate samples xt at each sampling step. e.g., - ``` - xt = ... - xt = correcting_xt_fn(xt, t, step) - ``` - thresholding_max_val: A `float`. The max value for thresholding. - Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. - dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). - Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, - Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models - with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) - self.noise_schedule = noise_schedule - assert algorithm_type in ["dpmsolver", "dpmsolver++"] - self.algorithm_type = algorithm_type - if correcting_x0_fn == "dynamic_thresholding": - self.correcting_x0_fn = self.dynamic_thresholding_fn - else: - self.correcting_x0_fn = correcting_x0_fn - self.correcting_xt_fn = correcting_xt_fn - self.dynamic_thresholding_ratio = dynamic_thresholding_ratio - self.thresholding_max_val = thresholding_max_val - - def dynamic_thresholding_fn(self, x0, t): - """ - The dynamic thresholding method. - """ - dims = x0.dim() - p = self.dynamic_thresholding_ratio - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims( - torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), - dims, - ) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with corrector). - """ - noise = self.noise_prediction_fn(x, t) - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - sigma_t * noise) / alpha_t - if self.correcting_x0_fn is not None: - x0 = self.correcting_x0_fn(x0, t) - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.algorithm_type == "dpmsolver++": - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). - """ - if skip_type == "logSNR": - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == "time_uniform": - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == "time_quadratic": - t_order = 2 - t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) - ) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3,] * ( - K - 2 - ) + [2, 1] - elif steps % 3 == 1: - orders = [3,] * ( - K - 1 - ) + [1] - else: - orders = [3,] * ( - K - 1 - ) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [ - 2, - ] * K - else: - K = steps // 2 + 1 - orders = [2,] * ( - K - 1 - ) + [1] - elif order == 1: - K = 1 - orders = [ - 1, - ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == "logSNR": - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum( - torch.tensor( - [ - 0, - ] - + orders - ), - 0, - ).to(device) - ] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s - if return_intermediate: - return x_t, {"model_s": model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s - if return_intermediate: - return x_t, {"model_s": model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update( - self, - x, - s, - t, - r1=0.5, - model_s=None, - return_intermediate=False, - solver_type="dpmsolver", - ): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ( - ns.marginal_log_mean_coeff(s), - ns.marginal_log_mean_coeff(s1), - ns.marginal_log_mean_coeff(t), - ) - sigma_s, sigma_s1, sigma_t = ( - ns.marginal_std(s), - ns.marginal_std(s1), - ns.marginal_std(t), - ) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.algorithm_type == "dpmsolver++": - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - if solver_type == "dpmsolver": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) - ) - elif solver_type == "taylor": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - if solver_type == "dpmsolver": - x_t = ( - torch.exp(log_alpha_t - log_alpha_s) * x - - (sigma_t * phi_1) * model_s - - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) - ) - elif solver_type == "taylor": - x_t = ( - torch.exp(log_alpha_t - log_alpha_s) * x - - (sigma_t * phi_1) * model_s - - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {"model_s": model_s, "model_s1": model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update( - self, - x, - s, - t, - r1=1.0 / 3.0, - r2=2.0 / 3.0, - model_s=None, - model_s1=None, - return_intermediate=False, - solver_type="dpmsolver", - ): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1.0 / 3.0 - if r2 is None: - r2 = 2.0 / 3.0 - ns = self.noise_schedule - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( - ns.marginal_log_mean_coeff(s), - ns.marginal_log_mean_coeff(s1), - ns.marginal_log_mean_coeff(s2), - ns.marginal_log_mean_coeff(t), - ) - sigma_s, sigma_s1, sigma_s2, sigma_t = ( - ns.marginal_std(s), - ns.marginal_std(s1), - ns.marginal_std(s2), - ns.marginal_std(t), - ) - alpha_s1, alpha_s2, alpha_t = ( - torch.exp(log_alpha_s1), - torch.exp(log_alpha_s2), - torch.exp(log_alpha_t), - ) - - if self.algorithm_type == "dpmsolver++": - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 - phi_2 = phi_1 / h + 1.0 - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - (sigma_s2 / sigma_s) * x - - (alpha_s2 * phi_12) * model_s - + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == "dpmsolver": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) - ) - elif solver_type == "taylor": - D1_0 = (1.0 / r1) * (model_s1 - model_s) - D1_1 = (1.0 / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (alpha_t * phi_2) * D1 - - (alpha_t * phi_3) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 - phi_2 = phi_1 / h - 1.0 - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - (torch.exp(log_alpha_s2 - log_alpha_s)) * x - - (sigma_s2 * phi_12) * model_s - - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == "dpmsolver": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_s)) * x - - (sigma_t * phi_1) * model_s - - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) - ) - elif solver_type == "taylor": - D1_0 = (1.0 / r1) * (model_s1 - model_s) - D1_1 = (1.0 / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - (torch.exp(log_alpha_t - log_alpha_s)) * x - - (sigma_t * phi_1) * model_s - - (sigma_t * phi_2) * D1 - - (sigma_t * phi_3) * D2 - ) - - if return_intermediate: - return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] - t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] - lambda_prev_1, lambda_prev_0, lambda_t = ( - ns.marginal_lambda(t_prev_1), - ns.marginal_lambda(t_prev_0), - ns.marginal_lambda(t), - ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - if solver_type == "dpmsolver": - x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 - elif solver_type == "taylor": - x_t = ( - (sigma_t / sigma_prev_0) * x - - (alpha_t * phi_1) * model_prev_0 - + (alpha_t * (phi_1 / h + 1.0)) * D1_0 - ) - else: - phi_1 = torch.expm1(h) - if solver_type == "dpmsolver": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - 0.5 * (sigma_t * phi_1) * D1_0 - ) - elif solver_type == "taylor": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - (sigma_t * (phi_1 / h - 1.0)) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( - ns.marginal_lambda(t_prev_2), - ns.marginal_lambda(t_prev_1), - ns.marginal_lambda(t_prev_0), - ns.marginal_lambda(t), - ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) - D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - phi_2 = phi_1 / h + 1.0 - phi_3 = phi_2 / h - 0.5 - x_t = ( - (sigma_t / sigma_prev_0) * x - - (alpha_t * phi_1) * model_prev_0 - + (alpha_t * phi_2) * D1 - - (alpha_t * phi_3) * D2 - ) - else: - phi_1 = torch.expm1(h) - phi_2 = phi_1 / h - 1.0 - phi_3 = phi_2 / h - 0.5 - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - (sigma_t * phi_2) * D1 - - (sigma_t * phi_3) * D2 - ) - return x_t - - def singlestep_dpm_solver_update( - self, - x, - s, - t, - order, - return_intermediate=False, - solver_type="dpmsolver", - r1=None, - r2=None, - ): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update( - x, - s, - t, - return_intermediate=return_intermediate, - solver_type=solver_type, - r1=r1, - ) - elif order == 3: - return self.singlestep_dpm_solver_third_update( - x, - s, - t, - return_intermediate=return_intermediate, - solver_type=solver_type, - r1=r1, - r2=r2, - ) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive( - self, - x, - order, - t_T, - t_0, - h_init=0.05, - atol=0.0078, - rtol=0.05, - theta=0.9, - t_err=1e-5, - solver_type="dpmsolver", - ): - """ - The adaptive step size solver based on singlestep DPM-Solver. - - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. - """ - ns = self.noise_schedule - s = t_T * torch.ones((1,)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - - def lower_update(x, s, t): - return self.dpm_solver_first_update(x, s, t, return_intermediate=True) - - def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) - - elif order == 3: - r1, r2 = 1.0 / 3.0, 2.0 / 3.0 - - def lower_update(x, s, t): - return self.singlestep_dpm_solver_second_update( - x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type - ) - - def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) - - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max( - torch.ones_like(x).to(x) * atol, - rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), - ) - - def norm_fn(v): - return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.0): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min( - theta * h * torch.float_power(E, -1.0 / order).float(), - lambda_0 - lambda_s, - ) - nfe += order - print("adaptive solver nfe", nfe) - return x - - def add_noise(self, x, t, noise=None): - """ - Compute the noised input xt = alpha_t * x + sigma_t * noise. - - Args: - x: A `torch.Tensor` with shape `(batch_size, *shape)`. - t: A `torch.Tensor` with shape `(t_size,)`. - Returns: - xt with shape `(t_size, batch_size, *shape)`. - """ - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - if noise is None: - noise = torch.randn((t.shape[0], *x.shape), device=x.device) - x = x.reshape((-1, *x.shape)) - xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise - if t.shape[0] == 1: - return xt.squeeze(0) - else: - return xt - - def inverse( - self, - x, - steps=20, - t_start=None, - t_end=None, - order=2, - skip_type="time_uniform", - method="multistep", - lower_order_final=True, - denoise_to_zero=False, - solver_type="dpmsolver", - atol=0.0078, - rtol=0.05, - return_intermediate=False, - ): - """ - Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. - For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. - """ - t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start - t_T = self.noise_schedule.T if t_end is None else t_end - assert ( - t_0 > 0 and t_T > 0 - ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" - return self.sample( - x, - steps=steps, - t_start=t_0, - t_end=t_T, - order=order, - skip_type=skip_type, - method=method, - lower_order_final=lower_order_final, - denoise_to_zero=denoise_to_zero, - solver_type=solver_type, - atol=atol, - rtol=rtol, - return_intermediate=return_intermediate, - ) - - def sample( - self, - x, - steps=20, - t_start=None, - t_end=None, - order=2, - skip_type="time_uniform", - method="multistep", - lower_order_final=True, - denoise_to_zero=False, - solver_type="dpmsolver", - atol=0.0078, - rtol=0.05, - return_intermediate=False, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - - ===================================================== - - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. - Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). - We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. - - ===================================================== - - Some advices for choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g., DPM-Solver: - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - e.g., DPM-Solver++: - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). - - This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID - for diffusion models sampling by diffusion SDEs for low-resolutional images - (such as CIFAR-10). However, we observed that such trick does not matter for - high-resolutional images. As it needs an additional NFE, we do not recommend - it for high-resolutional images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. - Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - return_intermediate: A `bool`. Whether to save the xt at each step. - When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - - """ - t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - assert ( - t_0 > 0 and t_T > 0 - ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" - if return_intermediate: - assert method in [ - "multistep", - "singlestep", - "singlestep_fixed", - ], "Cannot use adaptive solver when saving intermediate values" - if self.correcting_xt_fn is not None: - assert method in [ - "multistep", - "singlestep", - "singlestep_fixed", - ], "Cannot use adaptive solver when correcting_xt_fn is not None" - device = x.device - intermediates = [] - with torch.no_grad(): - if method == "adaptive": - x = self.dpm_solver_adaptive( - x, - order=order, - t_T=t_T, - t_0=t_0, - atol=atol, - rtol=rtol, - solver_type=solver_type, - ) - elif method == "multistep": - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - # Init the initial values. - step = 0 - t = timesteps[step] - t_prev_list = [t] - model_prev_list = [self.model_fn(x, t)] - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - # Init the first `order` values by lower order multistep DPM-Solver. - for step in range(1, order): - t = timesteps[step] - x = self.multistep_dpm_solver_update( - x, - model_prev_list, - t_prev_list, - t, - step, - solver_type=solver_type, - ) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - t_prev_list.append(t) - model_prev_list.append(self.model_fn(x, t)) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in range(order, steps + 1): - t = timesteps[step] - # We only use lower order for steps < 10 - if lower_order_final and steps < 10: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update( - x, - model_prev_list, - t_prev_list, - t, - step_order, - solver_type=solver_type, - ) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = t - # We do not need to evaluate the final model value. - if step < steps: - model_prev_list[-1] = self.model_fn(x, t) - elif method in ["singlestep", "singlestep_fixed"]: - if method == "singlestep": - (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver( - steps=steps, - order=order, - skip_type=skip_type, - t_T=t_T, - t_0=t_0, - device=device, - ) - elif method == "singlestep_fixed": - K = steps // order - orders = [ - order, - ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for step, order in enumerate(orders): - s, t = timesteps_outer[step], timesteps_outer[step + 1] - timesteps_inner = self.get_time_steps( - skip_type=skip_type, - t_T=s.item(), - t_0=t.item(), - N=order, - device=device, - ) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - else: - raise ValueError("Got wrong method {}".format(method)) - if denoise_to_zero: - t = torch.ones((1,)).to(device) * t_0 - x = self.denoise_to_zero_fn(x, t) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step + 1) - if return_intermediate: - intermediates.append(x) - if return_intermediate: - return x, intermediates - else: - return x - - -############################################################# -# other utility functions -############################################################# - - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), - torch.tensor(K - 2, device=x.device), - cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), - torch.tensor(K - 2, device=x.device), - cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,) * (dims - 1)] diff --git a/TTS/tts/layers/xtts/diffusion.py b/TTS/tts/layers/xtts/diffusion.py deleted file mode 100644 index 37665bc6..00000000 --- a/TTS/tts/layers/xtts/diffusion.py +++ /dev/null @@ -1,1319 +0,0 @@ -import enum -import math - -import numpy as np -import torch -import torch as th -from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral -from tqdm import tqdm - -from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} -SAMPLERS = ["dpm++2m", "p", "ddim"] - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - Compute the KL divergence between two gaussians. - - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, th.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for th.exp(). - logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] - - return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) - - -def approx_standard_normal_cdf(x): - """ - A fast approximation of the cumulative distribution function of the - standard normal. - """ - return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) - - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. - - :param x: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - :param means: the Gaussian mean Tensor. - :param log_scales: the Gaussian log stddev Tensor. - :return: a tensor like x of log probabilities (in nats). - """ - assert x.shape == means.shape == log_scales.shape - centered_x = x - means - inv_stdv = th.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 1.0 / 255.0) - cdf_plus = approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - 1.0 / 255.0) - cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = th.where( - x < -0.999, - log_cdf_plus, - th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), - ) - assert log_probs.shape == x.shape - return log_probs - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - """ - Get a pre-defined beta schedule for the given name. - - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. - """ - if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. - scale = 1000 / num_diffusion_timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -class ModelMeanType(enum.Enum): - """ - Which type of output the model predicts. - """ - - PREVIOUS_X = "previous_x" # the model predicts x_{t-1} - START_X = "start_x" # the model predicts x_0 - EPSILON = "epsilon" # the model predicts epsilon - - -class ModelVarType(enum.Enum): - """ - What is used as the model's output variance. - - The LEARNED_RANGE option has been added to allow the model to predict - values between FIXED_SMALL and FIXED_LARGE, making its job easier. - """ - - LEARNED = "learned" - FIXED_SMALL = "fixed_small" - FIXED_LARGE = "fixed_large" - LEARNED_RANGE = "learned_range" - - -class LossType(enum.Enum): - MSE = "mse" # use raw MSE loss (and KL when learning variances) - RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) - KL = "kl" # use the variational lower-bound - RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB - - def is_vb(self): - return self == LossType.KL or self == LossType.RESCALED_KL - - -class GaussianDiffusion: - """ - Utilities for training and sampling diffusion models. - - Ported directly from here, and then adapted over time to further experimentation. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 - - :param betas: a 1-D numpy array of betas for each diffusion timestep, - starting at T and going to 1. - :param model_mean_type: a ModelMeanType determining what the model outputs. - :param model_var_type: a ModelVarType determining how variance is output. - :param loss_type: a LossType determining the loss function to use. - :param rescale_timesteps: if True, pass floating point timesteps into the - model so that they are always scaled like in the - original paper (0 to 1000). - """ - - def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type, - rescale_timesteps=False, # this is generally False - conditioning_free=False, - conditioning_free_k=1, - ramp_conditioning_free=True, - sampler="ddim", - ): - self.sampler = sampler - self.model_mean_type = ModelMeanType(model_mean_type) - self.model_var_type = ModelVarType(model_var_type) - self.loss_type = LossType(loss_type) - self.rescale_timesteps = rescale_timesteps - self.conditioning_free = conditioning_free - self.conditioning_free_k = conditioning_free_k - self.ramp_conditioning_free = ramp_conditioning_free - - # Use float64 for accuracy. - betas = np.array(betas, dtype=np.float64) - self.betas = betas - assert len(betas.shape) == 1, "betas must be 1-D" - assert (betas > 0).all() and (betas <= 1).all() - - self.num_timesteps = int(betas.shape[0]) - - alphas = 1.0 - betas - self.alphas_cumprod = np.cumprod(alphas, axis=0) - self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) - self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) - assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) - self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) - self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) - self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - # log calculation clipped because the posterior variance is 0 at the - # beginning of the diffusion chain. - self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) - self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data for a given number of diffusion steps. - - In other words, sample from q(x_t | x_0). - - :param x_start: the initial data batch. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :param noise: if specified, the split-out normal noise. - :return: A noisy version of x_start. - """ - if noise is None: - noise = th.randn_like(x_start) - assert noise.shape == x_start.shape - return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior: - - q(x_{t-1} | x_t, x_0) - - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - assert ( - posterior_mean.shape[0] - == posterior_variance.shape[0] - == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] - ) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): - """ - Apply the model to get p(x_{t-1} | x_t), as well as a prediction of - the initial x, x_0. - - :param model: the model, which takes a signal and a batch of timesteps - as input. - :param x: the [N x C x ...] tensor at time t. - :param t: a 1-D Tensor of timesteps. - :param clip_denoised: if True, clip the denoised signal into [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. Applies before - clip_denoised. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict with the following keys: - - 'mean': the model mean output. - - 'variance': the model variance output. - - 'log_variance': the log of 'variance'. - - 'pred_xstart': the prediction for x_0. - """ - if model_kwargs is None: - model_kwargs = {} - - assert self.model_var_type == ModelVarType.LEARNED_RANGE - assert self.model_mean_type == ModelMeanType.EPSILON - assert denoised_fn is None - assert clip_denoised is True - B, C = x.shape[:2] - assert t.shape == (B,) - model_output = model(x, self._scale_timesteps(t), **model_kwargs) - if self.conditioning_free: - model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) - - if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: - assert model_output.shape == (B, C * 2, *x.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) - if self.conditioning_free: - model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) - if self.model_var_type == ModelVarType.LEARNED: - assert False - model_log_variance = model_var_values - model_variance = th.exp(model_log_variance) - else: - min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) - max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) - # The model_var_values is [-1, 1] for [min_var, max_var]. - frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log - model_variance = th.exp(model_log_variance) - else: - assert False - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so - # to get a better decoder log likelihood. - ModelVarType.FIXED_LARGE: ( - np.append(self.posterior_variance[1], self.betas[1:]), - np.log(np.append(self.posterior_variance[1], self.betas[1:])), - ), - ModelVarType.FIXED_SMALL: ( - self.posterior_variance, - self.posterior_log_variance_clipped, - ), - }[self.model_var_type] - model_variance = _extract_into_tensor(model_variance, t, x.shape) - model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) - - if self.conditioning_free: - if self.ramp_conditioning_free: - assert t.shape[0] == 1 # This should only be used in inference. - cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) - else: - cfk = self.conditioning_free_k - model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning - - def process_xstart(x): - if denoised_fn is not None: - assert False - x = denoised_fn(x) - if clip_denoised: - return x.clamp(-1, 1) - assert False - return x - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - assert False - pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) - model_mean = model_output - elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: - if self.model_mean_type == ModelMeanType.START_X: - assert False - pred_xstart = process_xstart(model_output) - else: - pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) - model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) - else: - raise NotImplementedError(self.model_mean_type) - - assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - return { - "mean": model_mean, - "variance": model_variance, - "log_variance": model_log_variance, - "pred_xstart": pred_xstart, - } - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps - ) - - def _predict_xstart_from_xprev(self, x_t, t, xprev): - assert x_t.shape == xprev.shape - return ( # (xprev - coef2*x_t) / coef1 - _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t - ) - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _scale_timesteps(self, t): - if self.rescale_timesteps: - return t.float() * (1000.0 / self.num_timesteps) - return t - - def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute the mean for the previous step, given a function cond_fn that - computes the gradient of a conditional log probability with respect to - x. In particular, cond_fn computes grad(log(p(y|x))), and we want to - condition on y. - - This uses the conditioning strategy from Sohl-Dickstein et al. (2015). - """ - gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) - new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - return new_mean - - def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute what the p_mean_variance output would have been, should the - model's score function be conditioned by cond_fn. - - See condition_mean() for details on cond_fn. - - Unlike condition_mean(), this instead uses the conditioning strategy - from Song et al (2020). - """ - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - - eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) - - out = p_mean_var.copy() - out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) - return out - - def p_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - ): - """ - Sample x_{t-1} from the model at the given timestep. - - :param model: the model to sample from. - :param x: the current tensor at x_{t-1}. - :param t: the value of t, starting at 0 for the first diffusion step. - :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict containing the following keys: - - 'sample': a random sample from the model. - - 'pred_xstart': a prediction of x_0. - """ - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - noise = th.randn_like(x) - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - if cond_fn is not None: - out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) - sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def k_diffusion_sample_loop( - self, - k_sampler, - pbar, - model, - shape, - noise=None, # all given - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - device=None, # ALL UNUSED - model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings}, - progress=False, # unused as well - ): - assert isinstance(model_kwargs, dict) - if device is None: - device = next(model.parameters()).device - s_in = noise.new_ones([noise.shape[0]]) - - def model_split(*args, **kwargs): - model_output = model(*args, **kwargs) - model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) - return model_epsilon, model_var - - # - """ - print(self.betas) - print(th.tensor(self.betas)) - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) - """ - noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) - - def model_fn_prewrap(x, t, *args, **kwargs): - """ - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - print(t) - print(self.timestep_map) - exit() - """ - """ - model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs) - out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs) - return out['pred_xstart'] - """ - x, _ = x.chunk(2) - t, _ = (t * 1000).chunk(2) - res = torch.cat( - [ - model_split(x, t, conditioning_free=True, **model_kwargs)[0], - model_split(x, t, **model_kwargs)[0], - ] - ) - pbar.update(1) - return res - - model_fn = model_wrapper( - model_fn_prewrap, - noise_schedule, - model_type="noise", # "noise" or "x_start" or "v" or "score" - model_kwargs=model_kwargs, - guidance_type="classifier-free", - condition=th.Tensor(1), - unconditional_condition=th.Tensor(1), - guidance_scale=self.conditioning_free_k, - ) - """ - model_fn = model_wrapper( - model_fn_prewrap, - noise_schedule, - model_type='x_start', - model_kwargs={} - ) - # - dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") - x_sample = dpm_solver.sample( - noise, - steps=20, - order=3, - skip_type="time_uniform", - method="singlestep", - ) - """ - dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - x_sample = dpm_solver.sample( - noise, - steps=self.num_timesteps, - order=2, - skip_type="time_uniform", - method="multistep", - ) - #''' - return x_sample - - # HF DIFFUSION ATTEMPT - """ - from .hf_diffusion import EulerAncestralDiscreteScheduler - Scheduler = EulerAncestralDiscreteScheduler() - Scheduler.set_timesteps(100) - for timestep in Scheduler.timesteps: - noise_input = Scheduler.scale_model_input(noise, timestep) - ts = s_in * timestep - model_output = model(noise_input, ts, **model_kwargs) - model_epsilon, _model_var = th.split(model_output, model_output.shape[1]//2, dim=1) - noise, _x0 = Scheduler.step(model_epsilon, timestep, noise) - return noise - """ - - # KARRAS DIFFUSION ATTEMPT - """ - TRAINED_DIFFUSION_STEPS = 4000 # HARDCODED - ratio = TRAINED_DIFFUSION_STEPS/14.5 - def call_model(*args, **kwargs): - model_output = model(*args, **kwargs) - model_output, model_var_values = th.split(model_output, model_output.shape[1]//2, dim=1) - return model_output - print(get_sigmas_karras(self.num_timesteps, sigma_min=0.0, sigma_max=4000, device=device)) - exit() - sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) - return k_sampler(call_model, noise, sigmas, extra_args=model_kwargs, disable=not progress) - ''' - sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) - step = 0 # LMAO - global_sigmas = None - # - def fakemodel(x, t, **model_kwargs): - print(t,global_sigmas*ratio) - return model(x, t, **model_kwargs) - def denoised(x, sigmas, **extra_args): - t = th.tensor([self.num_timesteps-step-1] * shape[0], device=device) - nonlocal global_sigmas - global_sigmas = sigmas - with th.no_grad(): - out = self.p_sample( - fakemodel, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - ) - return out["sample"] - def callback(d): - nonlocal step - step += 1 - - return k_sampler(denoised, noise, sigmas, extra_args=model_kwargs, callback=callback, disable=not progress) - ''' - """ - - def sample_loop(self, *args, **kwargs): - s = self.sampler - if s == "p": - return self.p_sample_loop(*args, **kwargs) - elif s == "ddim": - return self.ddim_sample_loop(*args, **kwargs) - elif s == "dpm++2m": - if self.conditioning_free is not True: - raise RuntimeError("cond_free must be true") - with tqdm(total=self.num_timesteps) as pbar: - return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) - else: - raise RuntimeError("sampler not impl") - - def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - ): - """ - Generate samples from the model. - - :param model: the model module. - :param shape: the shape of the samples, (N, C, H, W). - :param noise: if specified, the noise from the encoder to sample. - Should be of the same shape as `shape`. - :param clip_denoised: if True, clip x_start predictions to [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param device: if specified, the device to create the samples on. - If not specified, use a model parameter's device. - :param progress: if True, show a tqdm progress bar. - :return: a non-differentiable batch of samples. - """ - final = None - for sample in self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, - ): - final = sample - return final["sample"] - - def p_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - ): - """ - Generate samples from the model and yield intermediate samples from - each timestep of diffusion. - - Arguments are the same as p_sample_loop(). - Returns a generator over dicts, where each dict is the return value of - p_sample(). - """ - if device is None: - device = next(model.parameters()).device - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] - - for i in tqdm(indices, disable=not progress): - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): - out = self.p_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - ) - yield out - img = out["sample"] - - def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t-1} from the model using DDIM. - - Same usage as p_sample(). - """ - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - if cond_fn is not None: - out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) - - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) - - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) - # Equation 12. - noise = th.randn_like(x) - mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - sample = mean_pred + nonzero_mask * sigma * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t+1} from the model using DDIM reverse ODE. - """ - assert eta == 0.0, "Reverse ODE only for deterministic path" - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) - alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) - - # Equation 12. reversed - mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps - - return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} - - def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - ): - """ - Generate samples from the model using DDIM. - - Same usage as p_sample_loop(). - """ - final = None - for sample in self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, - eta=eta, - ): - final = sample - return final["sample"] - - def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - ): - """ - Use DDIM to sample from the model and yield intermediate samples from - each timestep of DDIM. - - Same usage as p_sample_loop_progressive(). - """ - if device is None: - device = next(model.parameters()).device - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] - - if progress: - # Lazy import so that we don't depend on tqdm. - from tqdm.auto import tqdm - - indices = tqdm(indices, disable=not progress) - - for i in indices: - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): - out = self.ddim_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - eta=eta, - ) - yield out - img = out["sample"] - - def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): - """ - Get a term for the variational lower-bound. - - The resulting units are bits (rather than nats, as one might expect). - This allows for comparison to other papers. - - :return: a dict with the following keys: - - 'output': a shape [N] tensor of NLLs or KLs. - - 'pred_xstart': the x_0 predictions. - """ - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) - out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) - kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) - kl = mean_flat(kl) / np.log(2.0) - - decoder_nll = -discretized_gaussian_log_likelihood( - x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] - ) - assert decoder_nll.shape == x_start.shape - decoder_nll = mean_flat(decoder_nll) / np.log(2.0) - - # At the first timestep return the decoder NLL, - # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = th.where((t == 0), decoder_nll, kl) - return {"output": output, "pred_xstart": out["pred_xstart"]} - - def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): - """ - Compute training losses for a single timestep. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param t: a batch of timestep indices. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param noise: if specified, the specific Gaussian noise to try to remove. - :return: a dict with the key "loss" containing a tensor of shape [N]. - Some mean or variance settings may also have other keys. - """ - if model_kwargs is None: - model_kwargs = {} - if noise is None: - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) - - terms = {} - - if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: - # TODO: support multiple model outputs for this mode. - terms["loss"] = self._vb_terms_bpd( - model=model, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - model_kwargs=model_kwargs, - )["output"] - if self.loss_type == LossType.RESCALED_KL: - terms["loss"] *= self.num_timesteps - elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) - if isinstance(model_outputs, tuple): - model_output = model_outputs[0] - terms["extra_outputs"] = model_outputs[1:] - else: - model_output = model_outputs - - if self.model_var_type in [ - ModelVarType.LEARNED, - ModelVarType.LEARNED_RANGE, - ]: - B, C = x_t.shape[:2] - assert model_output.shape == (B, C * 2, *x_t.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) - # Learn the variance using the variational bound, but don't let - # it affect our mean prediction. - frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) - terms["vb"] = self._vb_terms_bpd( - model=lambda *args, r=frozen_out: r, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - )["output"] - if self.loss_type == LossType.RESCALED_MSE: - # Divide by 1000 for equivalence with initial implementation. - # Without a factor of 1/1000, the VB term hurts the MSE term. - terms["vb"] *= self.num_timesteps / 1000.0 - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] - x_start_pred = torch.zeros(x_start) # Not supported. - elif self.model_mean_type == ModelMeanType.START_X: - target = x_start - x_start_pred = model_output - elif self.model_mean_type == ModelMeanType.EPSILON: - target = noise - x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) - else: - raise NotImplementedError(self.model_mean_type) - assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) - terms["x_start_predicted"] = x_start_pred - if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] - else: - terms["loss"] = terms["mse"] - else: - raise NotImplementedError(self.loss_type) - - return terms - - def autoregressive_training_losses( - self, - model, - x_start, - t, - model_output_keys, - gd_out_key, - model_kwargs=None, - noise=None, - ): - """ - Compute training losses for a single timestep. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param t: a batch of timestep indices. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param noise: if specified, the specific Gaussian noise to try to remove. - :return: a dict with the key "loss" containing a tensor of shape [N]. - Some mean or variance settings may also have other keys. - """ - if model_kwargs is None: - model_kwargs = {} - if noise is None: - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) - terms = {} - if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: - assert False # not currently supported for this type of diffusion. - elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) - terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) - model_output = terms[gd_out_key] - if self.model_var_type in [ - ModelVarType.LEARNED, - ModelVarType.LEARNED_RANGE, - ]: - B, C = x_t.shape[:2] - assert model_output.shape == (B, C, 2, *x_t.shape[2:]) - model_output, model_var_values = ( - model_output[:, :, 0], - model_output[:, :, 1], - ) - # Learn the variance using the variational bound, but don't let - # it affect our mean prediction. - frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) - terms["vb"] = self._vb_terms_bpd( - model=lambda *args, r=frozen_out: r, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - )["output"] - if self.loss_type == LossType.RESCALED_MSE: - # Divide by 1000 for equivalence with initial implementation. - # Without a factor of 1/1000, the VB term hurts the MSE term. - terms["vb"] *= self.num_timesteps / 1000.0 - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] - x_start_pred = torch.zeros(x_start) # Not supported. - elif self.model_mean_type == ModelMeanType.START_X: - target = x_start - x_start_pred = model_output - elif self.model_mean_type == ModelMeanType.EPSILON: - target = noise - x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) - else: - raise NotImplementedError(self.model_mean_type) - assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) - terms["x_start_predicted"] = x_start_pred - if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] - else: - terms["loss"] = terms["mse"] - else: - raise NotImplementedError(self.loss_type) - - return terms - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - - This term can't be optimized, as it only depends on the encoder. - - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) - return mean_flat(kl_prior) / np.log(2.0) - - def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): - """ - Compute the entire variational lower-bound, measured in bits-per-dim, - as well as other related quantities. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param clip_denoised: if True, clip denoised samples. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - - :return: a dict containing the following keys: - - total_bpd: the total variational lower-bound, per batch element. - - prior_bpd: the prior term in the lower-bound. - - vb: an [N x T] tensor of terms in the lower-bound. - - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - - mse: an [N x T] tensor of epsilon MSEs for each timestep. - """ - device = x_start.device - batch_size = x_start.shape[0] - - vb = [] - xstart_mse = [] - mse = [] - for t in list(range(self.num_timesteps))[::-1]: - t_batch = th.tensor([t] * batch_size, device=device) - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) - # Calculate VLB term at the current timestep - with th.no_grad(): - out = self._vb_terms_bpd( - model, - x_start=x_start, - x_t=x_t, - t=t_batch, - clip_denoised=clip_denoised, - model_kwargs=model_kwargs, - ) - vb.append(out["output"]) - xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) - eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) - mse.append(mean_flat((eps - noise) ** 2)) - - vb = th.stack(vb, dim=1) - xstart_mse = th.stack(xstart_mse, dim=1) - mse = th.stack(mse, dim=1) - - prior_bpd = self._prior_bpd(x_start) - total_bpd = vb.sum(dim=1) + prior_bpd - return { - "total_bpd": total_bpd, - "prior_bpd": prior_bpd, - "vb": vb, - "xstart_mse": xstart_mse, - "mse": mse, - } - - -class SpacedDiffusion(GaussianDiffusion): - """ - A diffusion process which can skip steps in a base diffusion process. - - :param use_timesteps: a collection (sequence or set) of timesteps from the - original diffusion process to retain. - :param kwargs: the kwargs to create the base diffusion process. - """ - - def __init__(self, use_timesteps, **kwargs): - self.use_timesteps = set(use_timesteps) - self.timestep_map = [] - self.original_num_steps = len(kwargs["betas"]) - - base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa - last_alpha_cumprod = 1.0 - new_betas = [] - for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): - if i in self.use_timesteps: - new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) - last_alpha_cumprod = alpha_cumprod - self.timestep_map.append(i) - kwargs["betas"] = np.array(new_betas) - super().__init__(**kwargs) - - def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) - - def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().training_losses(self._wrap_model(model), *args, **kwargs) - - def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) - - def condition_mean(self, cond_fn, *args, **kwargs): - return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) - - def condition_score(self, cond_fn, *args, **kwargs): - return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) - - def _wrap_model(self, model, autoregressive=False): - if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): - return model - mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel - return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) - - def _scale_timesteps(self, t): - # Scaling is done by the wrapped model. - return t - - -def space_timesteps(num_timesteps, section_counts): - """ - Create a list of timesteps to use from an original diffusion process, - given the number of timesteps we want to take from equally-sized portions - of the original process. - - For example, if there's 300 timesteps and the section counts are [10,15,20] - then the first 100 timesteps are strided to be 10 timesteps, the second 100 - are strided to be 15 timesteps, and the final 100 are strided to be 20. - - If the stride is a string starting with "ddim", then the fixed striding - from the DDIM paper is used, and only one section is allowed. - - :param num_timesteps: the number of diffusion steps in the original - process to divide up. - :param section_counts: either a list of numbers, or a string containing - comma-separated numbers, indicating the step count - per section. As a special case, use "ddimN" where N - is a number of steps to use the striding from the - DDIM paper. - :return: a set of diffusion steps from the original process to use. - """ - if isinstance(section_counts, str): - if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) - for i in range(1, num_timesteps): - if len(range(0, num_timesteps, i)) == desired_count: - return set(range(0, num_timesteps, i)) - raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") - section_counts = [int(x) for x in section_counts.split(",")] - size_per = num_timesteps // len(section_counts) - extra = num_timesteps % len(section_counts) - start_idx = 0 - all_steps = [] - for i, section_count in enumerate(section_counts): - size = size_per + (1 if i < extra else 0) - if size < section_count: - raise ValueError(f"cannot divide section of {size} steps into {section_count}") - if section_count <= 1: - frac_stride = 1 - else: - frac_stride = (size - 1) / (section_count - 1) - cur_idx = 0.0 - taken_steps = [] - for _ in range(section_count): - taken_steps.append(start_idx + round(cur_idx)) - cur_idx += frac_stride - all_steps += taken_steps - start_idx += size - return set(all_steps) - - -class _WrappedModel: - def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): - self.model = model - self.timestep_map = timestep_map - self.rescale_timesteps = rescale_timesteps - self.original_num_steps = original_num_steps - - def __call__(self, x, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) - new_ts = map_tensor[ts] - if self.rescale_timesteps: - new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, new_ts, **kwargs) - - -class _WrappedAutoregressiveModel: - def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): - self.model = model - self.timestep_map = timestep_map - self.rescale_timesteps = rescale_timesteps - self.original_num_steps = original_num_steps - - def __call__(self, x, x0, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) - new_ts = map_tensor[ts] - if self.rescale_timesteps: - new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, x0, new_ts, **kwargs) - - -def _extract_into_tensor(arr, timesteps, broadcast_shape): - """ - Extract values from a 1-D numpy array for a batch of indices. - - :param arr: the 1-D numpy array. - :param timesteps: a tensor of indices into the array to extract. - :param broadcast_shape: a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() - while len(res.shape) < len(broadcast_shape): - res = res[..., None] - return res.expand(broadcast_shape) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 477f31bf..6b8a73e8 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -9,8 +9,6 @@ import torchaudio from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel -from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts -from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -168,12 +166,10 @@ class XttsAudioConfig(Coqpit): Args: sample_rate (int): The sample rate in which the GPT operates. - diffusion_sample_rate (int): The sample rate of the diffusion audio waveform. output_sample_rate (int): The sample rate of the output audio waveform. """ sample_rate: int = 22050 - diffusion_sample_rate: int = 24000 output_sample_rate: int = 24000 @@ -697,24 +693,11 @@ class Xtts(BaseTTS): hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) - else: - assert hasattr( - self, "diffusion_decoder" - ), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`" - mel = do_spectrogram_diffusion( - self.diffusion_decoder, - diffuser, - gpt_latents, - diffusion_conditioning, - temperature=diffusion_temperature, - ) - wav = self.vocoder.inference(mel) return { "wav": wav.cpu().numpy().squeeze(), "gpt_latents": gpt_latents, "speaker_embedding": speaker_embedding, - "diffusion_conditioning": diffusion_conditioning, } def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 9134be0d..65d3ccd0 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -98,7 +98,7 @@ def main(): ) # define audio config audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) # training parameters config config = GPTTrainerConfig( diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ee6b22be..3bb68e2f 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -99,7 +99,7 @@ def main(): ) # define audio config audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) # training parameters config config = GPTTrainerConfig( diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 47b1dd7d..09df98ef 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -89,7 +89,7 @@ model_args = GPTArgs( use_ne_hifigan=True, ) audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) config = GPTTrainerConfig( epochs=1, diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6b6f1330..0851a4e2 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -89,7 +89,7 @@ model_args = GPTArgs( use_ne_hifigan=True, ) audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) config = GPTTrainerConfig( epochs=1,