mirror of https://github.com/coqui-ai/TTS.git
Make k_diffusion optional
This commit is contained in:
parent
08d11e9198
commit
26efdf6ee7
|
@ -13,12 +13,19 @@ import math
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
|
||||
|
||||
|
||||
try:
|
||||
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
|
||||
|
||||
K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
|
||||
except ImportError:
|
||||
K_DIFFUSION_SAMPLERS = None
|
||||
|
||||
|
||||
SAMPLERS = ["dpm++2m", "p", "ddim"]
|
||||
|
||||
|
||||
|
@ -531,6 +538,8 @@ class GaussianDiffusion:
|
|||
if self.conditioning_free is not True:
|
||||
raise RuntimeError("cond_free must be true")
|
||||
with tqdm(total=self.num_timesteps) as pbar:
|
||||
if K_DIFFUSION_SAMPLERS is None:
|
||||
raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
|
||||
return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
|
||||
else:
|
||||
raise RuntimeError("sampler not impl")
|
||||
|
|
|
@ -46,7 +46,6 @@ bangla
|
|||
bnnumerizer
|
||||
bnunicodenormalizer
|
||||
#deps for tortoise
|
||||
k_diffusion
|
||||
einops>=0.6.0
|
||||
transformers>=4.33.0
|
||||
#deps for bark
|
||||
|
|
Loading…
Reference in New Issue