wavegrad updates

This commit is contained in:
erogol 2020-10-19 15:44:07 +02:00
parent a1582a0e12
commit 7bcdb7ac35
6 changed files with 129 additions and 66 deletions

View File

@ -50,22 +50,35 @@ def setup_loader(ap, is_val=False, verbose=False):
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.batch_size,
shuffle=False if num_gpus > 1 else True,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
return loader
def format_data(data):
# return a whole audio segment
m, y = data
m, x = data
if use_cuda:
m = m.cuda(non_blocking=True)
y = y.cuda(non_blocking=True)
return m, y
x = x.cuda(non_blocking=True)
return m, x
def format_test_data(data):
# return a whole audio segment
m, x = data
m = m.unsqueeze(0)
x = x.unsqueeze(0)
if use_cuda:
m = m.cuda(non_blocking=True)
x = x.cuda(non_blocking=True)
return m, x
def train(model, criterion, optimizer,
@ -81,26 +94,36 @@ def train(model, criterion, optimizer,
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
# setup noise schedule
noise_schedule = c['train_noise_schedule']
if hasattr(model, 'module'):
model.module.init_noise_schedule(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
else:
model.init_noise_schedule(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
for num_iter, data in enumerate(data_loader):
start_time = time.time()
# format data
m, y = format_data(data)
m, x = format_data(data)
loader_time = time.time() - end_time
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
y_noisy, noise_scale = model.module.compute_noisy_x(y)
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
else:
y_noisy, noise_scale = model.compute_noisy_x(y)
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
# forward pass
y_hat = model(y_noisy, m, noise_scale)
noise_hat = model(x_noisy, m, noise_scale)
# compute losses
loss = criterion(y_noisy, y_hat)
loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss}
# backward pass with loss scaling
@ -181,15 +204,6 @@ def train(model, criterion, optimizer,
OUT_PATH,
model_losses=loss_dict)
# compute spectrograms
figures = plot_results(y_hat[0], y[0], ap, global_step, 'train')
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -218,23 +232,23 @@ def evaluate(model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
m, y = format_data(data)
m, x = format_data(data)
loader_time = time.time() - end_time
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
y_noisy, noise_scale = model.module.compute_noisy_x(y)
noise, x_noisy, noise_scale = model.module.compute_noisy_x(x)
else:
y_noisy, noise_scale = model.compute_noisy_x(y)
noise, x_noisy, noise_scale = model.compute_noisy_x(x)
# forward pass
y_hat = model(y_noisy, m, noise_scale)
noise_hat = model(x_noisy, m, noise_scale)
# compute losses
loss = criterion(y_noisy, y_hat)
loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss}
@ -261,14 +275,32 @@ def evaluate(model, criterion, ap, global_step, epoch):
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
if args.rank == 0:
samples = data_loader.dataset.load_test_samples(1)
m, x = format_test_data(samples[0])
# setup noise schedule and inference
noise_schedule = c['test_noise_schedule']
if hasattr(model, 'module'):
model.module.init_noise_schedule(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
# compute voice
x_pred = model.module.inference(m)
else:
model.init_noise_schedule(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
# compute voice
x_pred = model.inference(m)
# compute spectrograms
figures = plot_results(y_hat, y, ap, global_step, 'eval')
figures = plot_results(x_pred, x, ap, global_step, 'eval')
tb_logger.tb_eval_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
c.audio["sample_rate"])
c.audio["sample_rate"])
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

View File

@ -92,8 +92,8 @@
// DATASET
"data_path": "/home/erogol/Data/MozillaMerged22050/wavs/",
"feature_path": null,
"seq_len": 16384,
"pad_short": 2000,
"seq_len": 6144,
"pad_short": 500,
"conv_pad": 0,
"use_noise_augment": false,
"use_cache": true,
@ -102,6 +102,16 @@
// TRAINING
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"train_noise_schedule":{
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 1000
},
"test_noise_schedule":{
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 50
}
// VALIDATION
"run_eval": true,

View File

@ -71,6 +71,16 @@
// TRAINING
"batch_size": 64, // Batch size for training.
"train_noise_schedule":{
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 1000
},
"test_noise_schedule":{
"min_val": 1e-6,
"max_val": 1e-2,
"num_steps": 50
},
// VALIDATION
"run_eval": true, // enable/disable evaluation run

View File

@ -62,6 +62,13 @@ class WaveGradDataset(Dataset):
item = self.load_item(idx)
return item
def load_test_samples(self, num_samples):
samples = []
for idx in range(num_samples):
mel, audio = self.load_item(idx)
samples.append([mel, audio])
return samples
def load_item(self, idx):
""" load (audio, feat) couple """
if self.compute_feat:

View File

@ -4,6 +4,16 @@ from torch import nn
from torch.nn import functional as F
class Conv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_parameters()
def reset_parameters(self):
nn.init.orthogonal_(self.weight)
nn.init.zeros_(self.bias)
class NoiseLevelEncoding(nn.Module):
"""Noise level encoding applying same
encoding vector to all time steps. It is
@ -25,7 +35,8 @@ class NoiseLevelEncoding(nn.Module):
"""
return (x + self.encoding(noise_level)[:, :, None])
def init_encoding(self, length):
@staticmethod
def init_encoding(length):
div_by = torch.arange(length) / length
enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0))
return enc
@ -46,8 +57,8 @@ class FiLM(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.encoding = NoiseLevelEncoding(in_channels)
self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1)
self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1)
self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1)
self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1)
self._init_parameters()
def _init_parameters(self):
@ -74,26 +85,26 @@ class UBlock(nn.Module):
assert len(dilations) == 4
self.upsample_factor = upsample_factor
self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1)
self.shortcut_conv = Conv1d(in_channels, hid_channels, 1)
self.main_block1 = nn.ModuleList([
nn.Conv1d(in_channels,
Conv1d(in_channels,
hid_channels,
3,
dilation=dilations[0],
padding=dilations[0]),
nn.Conv1d(hid_channels,
Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[1],
padding=dilations[1])
])
self.main_block2 = nn.ModuleList([
nn.Conv1d(hid_channels,
Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[2],
padding=dilations[2]),
nn.Conv1d(hid_channels,
Conv1d(hid_channels,
hid_channels,
3,
dilation=dilations[3],
@ -129,11 +140,11 @@ class DBlock(nn.Module):
def __init__(self, in_channels, hid_channels, downsample_factor):
super().__init__()
self.downsample_factor = downsample_factor
self.res_conv = nn.Conv1d(in_channels, hid_channels, 1)
self.res_conv = Conv1d(in_channels, hid_channels, 1)
self.main_convs = nn.ModuleList([
nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1),
Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2),
Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4),
])
def forward(self, x):

View File

@ -22,14 +22,6 @@ class Wavegrad(nn.Module):
assert len(upsample_factors) == len(upsample_dilations)
assert len(upsample_factors) == len(ublock_out_channels)
# inference time noise schedule params
self.S = 1000
beta, alpha, alpha_cum, noise_level = self._setup_noise_level()
self.register_buffer('beta', beta)
self.register_buffer('alpha', alpha)
self.register_buffer('alpha_cum', alpha_cum)
self.register_buffer('noise_level', noise_level)
# setup up-down sampling parameters
self.hop_length = np.prod(upsample_factors)
self.upsample_factors = upsample_factors
@ -68,21 +60,23 @@ class Wavegrad(nn.Module):
# print(ic, 'last_conv--', out_channels)
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
def _setup_noise_level(self, noise_schedule=None):
"""compute noise schedule parameters"""
if noise_schedule is None:
beta = np.linspace(1e-6, 0.01, self.S)
else:
beta = noise_schedule
alpha = 1 - beta
alpha_cum = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0)
# inference time noise schedule params
self.S = 1000
self.init_noise_schedule(self.S)
beta = torch.from_numpy(beta)
alpha = torch.from_numpy(alpha)
alpha_cum = torch.from_numpy(alpha_cum)
noise_level = torch.from_numpy(noise_level.astype(np.float32))
return beta, alpha, alpha_cum, noise_level
def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01):
"""compute noise schedule parameters"""
device = self.last_conv.weight.device
beta = torch.linspace(min_val, max_val, num_iter).to(device)
alpha = 1 - beta
alpha_cum = alpha.cumprod(dim=0)
noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5])
self.register_buffer('beta', beta)
self.register_buffer('alpha', alpha)
self.register_buffer('alpha_cum', alpha_cum)
self.register_buffer('noise_level', noise_level)
def compute_noisy_x(self, x):
B = x.shape[0]
@ -94,7 +88,7 @@ class Wavegrad(nn.Module):
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(x)
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
return noisy_x.unsqueeze(1), noise_scale[:, 0]
return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]
def forward(self, x, c, noise_scale):
assert len(c.shape) == 3 # B, C, T
@ -114,9 +108,8 @@ class Wavegrad(nn.Module):
def inference(self, c):
with torch.no_grad():
x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c)
noise_scale = torch.from_numpy(
self.alpha_cum**0.5).float().unsqueeze(1).to(c)
x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
for n in range(len(self.alpha) - 1, -1, -1):
c1 = 1 / self.alpha[n]**0.5
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5