mirror of https://github.com/coqui-ai/TTS.git
Change noise scheduling for WaveGrad: compute beta values externally to enable greater flexibility.
This commit is contained in:
parent 5a59467f34
commit c65712426a
@@ -59,11 +59,13 @@ if args.use_cuda:
 # setup optimization parameters
 base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
+exponents = 10 ** np.linspace(-6, -2, num=args.num_iter)
 best_error = float('inf')
 best_schedule = None
 total_search_iter = len(base_values)**args.num_iter
 for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
-    model.compute_noise_level(num_steps=args.num_iter, min_val=1e-6, max_val=1e-1, base_vals=base)
+    beta = exponents * base
+    model.compute_noise_level(beta)
     for data in loader:
         mel, audio = data
         y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
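For context, a minimal sketch of what the loop above now iterates over: each candidate schedule is a fixed geometric grid of exponents scaled elementwise by a tuple of base values. The sizes below are illustrative stand-ins for args.num_iter and args.search_depth.

import numpy as np
from itertools import product as cartesian_product

num_iter = 4        # stands in for args.num_iter (number of diffusion steps)
search_depth = 3    # stands in for args.search_depth (number of candidate base values)

base_values = sorted(np.random.uniform(high=10, size=search_depth))
exponents = 10 ** np.linspace(-6, -2, num=num_iter)

for base in cartesian_product(base_values, repeat=num_iter):
    beta = exponents * np.asarray(base)  # one candidate schedule, one beta value per step
    # each candidate is then scored via model.compute_noise_level(beta), as in the hunk below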
@@ -78,11 +80,11 @@ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
             mel_hat.append(torch.from_numpy(m))

         mel_hat = torch.stack(mel_hat)
-        mse = torch.sum((mel - mel_hat) ** 2)
+        mse = torch.sum((mel - mel_hat) ** 2).mean()
         if mse.item() < best_error:
             best_error = mse.item()
-            best_schedule = {'num_steps': args.num_iter, 'min_val':1e-6, 'max_val':1e-1, 'base_vals':base}
-            print(" > Found a better schedule.")
+            best_schedule = {'beta': beta}
+            print(f" > Found a better schedule. - MSE: {mse.item()}")
             np.save(args.output_path, best_schedule)
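The winning schedule is now saved as a plain dict with a single 'beta' key. A small round-trip sketch of that file format (file name and values are illustrative), matching how load_noise_schedule reads it back further down:

import numpy as np

beta = 10 ** np.linspace(-6, -2, num=50) * 2.5      # an illustrative tuned schedule
np.save("best_noise_schedule.npy", {"beta": beta})  # np.save pickles the dict into a .npy file

loaded = np.load("best_noise_schedule.npy", allow_pickle=True).item()["beta"]
assert np.allclose(loaded, beta)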
@@ -72,6 +72,8 @@
     // TRAINING
     "batch_size": 96, // Batch size for training.

+    // NOISE SCHEDULE PARAMS - Only effective at training time.
     "train_noise_schedule":{
         "min_val": 1e-6,
         "max_val": 1e-2,
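Because compute_noise_level no longer builds the schedule from min_val/max_val itself (see the model change below), these config values presumably get mapped to a beta array before reaching the model at training time. A hedged sketch of that mapping; the num_steps field is an assumption and is not visible in this hunk.

import numpy as np

train_noise_schedule = {
    "min_val": 1e-6,
    "max_val": 1e-2,
    "num_steps": 1000,  # assumed field; only min_val and max_val appear in the hunk above
}

beta = np.linspace(
    train_noise_schedule["min_val"],
    train_noise_schedule["max_val"],
    train_noise_schedule["num_steps"],
)
# model.compute_noise_level(beta)  # how a trainer could apply it with the new API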
@@ -79,8 +79,8 @@ class Wavegrad(nn.Module):
         return x

     def load_noise_schedule(self, path):
-        sched = np.load(path, allow_pickle=True).item()
-        self.compute_noise_level(**sched)
+        beta = np.load(path, allow_pickle=True).item()['beta']
+        self.compute_noise_level(beta)

     @torch.no_grad()
     def inference(self, x, y_n=None):
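After the refactor there are two ways to install a schedule on a model: load a tuned file written by the search script, or pass any beta array directly. A hypothetical helper sketching both paths; set_inference_schedule and its default values are illustrative, not part of the repo.

import numpy as np

def set_inference_schedule(model, tuned_path=None, num_steps=50):
    """Install a noise schedule on a Wavegrad model using the refactored API."""
    if tuned_path is not None:
        # a schedule tuned offline and saved as {'beta': ...}
        model.load_noise_schedule(tuned_path)
    else:
        # any externally computed beta array, e.g. a plain linear schedule
        model.compute_noise_level(np.linspace(1e-6, 1e-2, num_steps))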
@@ -113,16 +113,13 @@ class Wavegrad(nn.Module):
         noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]

-    def compute_noise_level(self, num_steps, min_val, max_val, base_vals=None):
+    def compute_noise_level(self, beta):
         """Compute noise schedule parameters"""
-        beta = np.linspace(min_val, max_val, num_steps)
-        if base_vals is not None:
-            beta *= base_vals
+        self.num_steps = len(beta)
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
         noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)

-        self.num_steps = num_steps
         # pylint: disable=not-callable
         self.beta = torch.tensor(beta.astype(np.float32))
         self.alpha = torch.tensor(alpha.astype(np.float32))
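For reference, the schedule math above reproduced as a standalone numpy sketch for an arbitrary externally supplied beta; values are illustrative and the torch tensor registration is omitted.

import numpy as np

def noise_schedule_params(beta):
    """Same derivation as compute_noise_level: beta -> alpha -> alpha_hat -> noise_level."""
    alpha = 1 - beta                      # per-step signal retention
    alpha_hat = np.cumprod(alpha)         # cumulative signal retention across steps
    noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
    return alpha, alpha_hat, noise_level

beta = 10 ** np.linspace(-6, -2, num=50) * 2.5  # illustrative beta, same shape the tuner produces
alpha, alpha_hat, noise_level = noise_schedule_params(beta)
assert len(noise_level) == len(beta) + 1        # noise_level is prepended with 1.0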