mirror of https://github.com/coqui-ai/TTS.git
tune_wavegrad update
This commit is contained in:
parent
d8c1b5b73d
commit
4b92ac0f92
|
@ -34,7 +34,7 @@ _, train_data = load_wav_data(args.data_path, 0)
|
||||||
train_data = train_data[:args.num_samples]
|
train_data = train_data[:args.num_samples]
|
||||||
dataset = WaveGradDataset(ap=ap,
|
dataset = WaveGradDataset(ap=ap,
|
||||||
items=train_data,
|
items=train_data,
|
||||||
seq_len=ap.hop_length * 100,
|
seq_len=-1,
|
||||||
hop_len=ap.hop_length,
|
hop_len=ap.hop_length,
|
||||||
pad_short=config.pad_short,
|
pad_short=config.pad_short,
|
||||||
conv_pad=config.conv_pad,
|
conv_pad=config.conv_pad,
|
||||||
|
@ -58,8 +58,9 @@ if args.use_cuda:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
||||||
# setup optimization parameters
|
# setup optimization parameters
|
||||||
base_values = sorted(np.random.uniform(high=10, size=args.search_depth))
|
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
||||||
exponents = 10 ** np.linspace(-6, -2, num=args.num_iter)
|
print(base_values)
|
||||||
|
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
||||||
best_error = float('inf')
|
best_error = float('inf')
|
||||||
best_schedule = None
|
best_schedule = None
|
||||||
total_search_iter = len(base_values)**args.num_iter
|
total_search_iter = len(base_values)**args.num_iter
|
||||||
|
|
|
@ -119,6 +119,7 @@ class Wavegrad(nn.Module):
|
||||||
alpha = 1 - beta
|
alpha = 1 - beta
|
||||||
alpha_hat = np.cumprod(alpha)
|
alpha_hat = np.cumprod(alpha)
|
||||||
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
|
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
|
||||||
|
noise_level = alpha_hat ** 0.5
|
||||||
|
|
||||||
# pylint: disable=not-callable
|
# pylint: disable=not-callable
|
||||||
self.beta = torch.tensor(beta.astype(np.float32))
|
self.beta = torch.tensor(beta.astype(np.float32))
|
||||||
|
|
Loading…
Reference in New Issue