From 7bcdb7ac3540da3a9377ff8a47ff227041a06963 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 19 Oct 2020 15:44:07 +0200 Subject: [PATCH] wavegrad updates --- TTS/bin/train_wavegrad.py | 84 +++++++++++++------ .../multiband_melgan_config_mozilla.json | 14 +++- TTS/vocoder/configs/wavegrad_libritts.json | 10 +++ TTS/vocoder/datasets/wavegrad_dataset.py | 7 ++ TTS/vocoder/layers/wavegrad.py | 35 +++++--- TTS/vocoder/models/wavegrad.py | 45 +++++----- 6 files changed, 129 insertions(+), 66 deletions(-) diff --git a/TTS/bin/train_wavegrad.py b/TTS/bin/train_wavegrad.py index 6d17b4f2..e167a4cb 100644 --- a/TTS/bin/train_wavegrad.py +++ b/TTS/bin/train_wavegrad.py @@ -50,22 +50,35 @@ def setup_loader(ap, is_val=False, verbose=False): sampler = DistributedSampler(dataset) if num_gpus > 1 else None loader = DataLoader(dataset, batch_size=c.batch_size, - shuffle=False if num_gpus > 1 else True, + shuffle=num_gpus <= 1, drop_last=False, sampler=sampler, num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, pin_memory=False) + + return loader def format_data(data): # return a whole audio segment - m, y = data + m, x = data if use_cuda: m = m.cuda(non_blocking=True) - y = y.cuda(non_blocking=True) - return m, y + x = x.cuda(non_blocking=True) + return m, x + + +def format_test_data(data): + # return a whole audio segment + m, x = data + m = m.unsqueeze(0) + x = x.unsqueeze(0) + if use_cuda: + m = m.cuda(non_blocking=True) + x = x.cuda(non_blocking=True) + return m, x def train(model, criterion, optimizer, @@ -81,26 +94,36 @@ def train(model, criterion, optimizer, batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() c_logger.print_train_start() + # setup noise schedule + noise_schedule = c['train_noise_schedule'] + if hasattr(model, 'module'): + model.module.init_noise_schedule(noise_schedule['num_steps'], + noise_schedule['min_val'], + noise_schedule['max_val']) + else: + model.init_noise_schedule(noise_schedule['num_steps'], + noise_schedule['min_val'], + noise_schedule['max_val']) for num_iter, data in enumerate(data_loader): start_time = time.time() # format data - m, y = format_data(data) + m, x = format_data(data) loader_time = time.time() - end_time global_step += 1 # compute noisy input if hasattr(model, 'module'): - y_noisy, noise_scale = model.module.compute_noisy_x(y) + noise, x_noisy, noise_scale = model.module.compute_noisy_x(x) else: - y_noisy, noise_scale = model.compute_noisy_x(y) + noise, x_noisy, noise_scale = model.compute_noisy_x(x) # forward pass - y_hat = model(y_noisy, m, noise_scale) + noise_hat = model(x_noisy, m, noise_scale) # compute losses - loss = criterion(y_noisy, y_hat) + loss = criterion(noise, noise_hat) loss_wavegrad_dict = {'wavegrad_loss':loss} # backward pass with loss scaling @@ -181,15 +204,6 @@ def train(model, criterion, optimizer, OUT_PATH, model_losses=loss_dict) - # compute spectrograms - figures = plot_results(y_hat[0], y[0], ap, global_step, 'train') - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_train_audios(global_step, - {'train/audio': sample_voice}, - c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -218,23 +232,23 @@ def evaluate(model, criterion, ap, global_step, epoch): start_time = time.time() # format data - m, y = format_data(data) + m, x = format_data(data) loader_time = time.time() - end_time global_step += 1 # compute noisy input if hasattr(model, 'module'): - y_noisy, noise_scale = model.module.compute_noisy_x(y) + noise, x_noisy, noise_scale = model.module.compute_noisy_x(x) else: - y_noisy, noise_scale = model.compute_noisy_x(y) + noise, x_noisy, noise_scale = model.compute_noisy_x(x) # forward pass - y_hat = model(y_noisy, m, noise_scale) + noise_hat = model(x_noisy, m, noise_scale) # compute losses - loss = criterion(y_noisy, y_hat) + loss = criterion(noise, noise_hat) loss_wavegrad_dict = {'wavegrad_loss':loss} @@ -261,14 +275,32 @@ def evaluate(model, criterion, ap, global_step, epoch): c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if args.rank == 0: + samples = data_loader.dataset.load_test_samples(1) + m, x = format_test_data(samples[0]) + + # setup noise schedule and inference + noise_schedule = c['test_noise_schedule'] + if hasattr(model, 'module'): + model.module.init_noise_schedule(noise_schedule['num_steps'], + noise_schedule['min_val'], + noise_schedule['max_val']) + # compute voice + x_pred = model.module.inference(m) + else: + model.init_noise_schedule(noise_schedule['num_steps'], + noise_schedule['min_val'], + noise_schedule['max_val']) + # compute voice + x_pred = model.inference(m) + # compute spectrograms - figures = plot_results(y_hat, y, ap, global_step, 'eval') + figures = plot_results(x_pred, x, ap, global_step, 'eval') tb_logger.tb_eval_figures(global_step, figures) # Sample audio - sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy() tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, - c.audio["sample_rate"]) + c.audio["sample_rate"]) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) diff --git a/TTS/vocoder/configs/multiband_melgan_config_mozilla.json b/TTS/vocoder/configs/multiband_melgan_config_mozilla.json index 35f1642a..4978d42f 100644 --- a/TTS/vocoder/configs/multiband_melgan_config_mozilla.json +++ b/TTS/vocoder/configs/multiband_melgan_config_mozilla.json @@ -92,8 +92,8 @@ // DATASET "data_path": "/home/erogol/Data/MozillaMerged22050/wavs/", "feature_path": null, - "seq_len": 16384, - "pad_short": 2000, + "seq_len": 6144, + "pad_short": 500, "conv_pad": 0, "use_noise_augment": false, "use_cache": true, @@ -102,6 +102,16 @@ // TRAINING "batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "train_noise_schedule":{ + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 1000 + }, + "test_noise_schedule":{ + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50 + } // VALIDATION "run_eval": true, diff --git a/TTS/vocoder/configs/wavegrad_libritts.json b/TTS/vocoder/configs/wavegrad_libritts.json index 98de36c7..9bb1154b 100644 --- a/TTS/vocoder/configs/wavegrad_libritts.json +++ b/TTS/vocoder/configs/wavegrad_libritts.json @@ -71,6 +71,16 @@ // TRAINING "batch_size": 64, // Batch size for training. + "train_noise_schedule":{ + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 1000 + }, + "test_noise_schedule":{ + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50 + }, // VALIDATION "run_eval": true, // enable/disable evaluation run diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 4a70c252..1e4f9e11 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -62,6 +62,13 @@ class WaveGradDataset(Dataset): item = self.load_item(idx) return item + def load_test_samples(self, num_samples): + samples = [] + for idx in range(num_samples): + mel, audio = self.load_item(idx) + samples.append([mel, audio]) + return samples + def load_item(self, idx): """ load (audio, feat) couple """ if self.compute_feat: diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py index 69bca0a8..c7549676 100644 --- a/TTS/vocoder/layers/wavegrad.py +++ b/TTS/vocoder/layers/wavegrad.py @@ -4,6 +4,16 @@ from torch import nn from torch.nn import functional as F +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_parameters() + + def reset_parameters(self): + nn.init.orthogonal_(self.weight) + nn.init.zeros_(self.bias) + + class NoiseLevelEncoding(nn.Module): """Noise level encoding applying same encoding vector to all time steps. It is @@ -25,7 +35,8 @@ class NoiseLevelEncoding(nn.Module): """ return (x + self.encoding(noise_level)[:, :, None]) - def init_encoding(self, length): + @staticmethod + def init_encoding(length): div_by = torch.arange(length) / length enc = torch.exp(-math.log(1e4) * div_by.unsqueeze(0)) return enc @@ -46,8 +57,8 @@ class FiLM(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() self.encoding = NoiseLevelEncoding(in_channels) - self.conv_in = nn.Conv1d(in_channels, in_channels, 3, padding=1) - self.conv_out = nn.Conv1d(in_channels, out_channels * 2, 3, padding=1) + self.conv_in = Conv1d(in_channels, in_channels, 3, padding=1) + self.conv_out = Conv1d(in_channels, out_channels * 2, 3, padding=1) self._init_parameters() def _init_parameters(self): @@ -74,26 +85,26 @@ class UBlock(nn.Module): assert len(dilations) == 4 self.upsample_factor = upsample_factor - self.shortcut_conv = nn.Conv1d(in_channels, hid_channels, 1) + self.shortcut_conv = Conv1d(in_channels, hid_channels, 1) self.main_block1 = nn.ModuleList([ - nn.Conv1d(in_channels, + Conv1d(in_channels, hid_channels, 3, dilation=dilations[0], padding=dilations[0]), - nn.Conv1d(hid_channels, + Conv1d(hid_channels, hid_channels, 3, dilation=dilations[1], padding=dilations[1]) ]) self.main_block2 = nn.ModuleList([ - nn.Conv1d(hid_channels, + Conv1d(hid_channels, hid_channels, 3, dilation=dilations[2], padding=dilations[2]), - nn.Conv1d(hid_channels, + Conv1d(hid_channels, hid_channels, 3, dilation=dilations[3], @@ -129,11 +140,11 @@ class DBlock(nn.Module): def __init__(self, in_channels, hid_channels, downsample_factor): super().__init__() self.downsample_factor = downsample_factor - self.res_conv = nn.Conv1d(in_channels, hid_channels, 1) + self.res_conv = Conv1d(in_channels, hid_channels, 1) self.main_convs = nn.ModuleList([ - nn.Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1), - nn.Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2), - nn.Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4), + Conv1d(in_channels, hid_channels, 3, dilation=1, padding=1), + Conv1d(hid_channels, hid_channels, 3, dilation=2, padding=2), + Conv1d(hid_channels, hid_channels, 3, dilation=4, padding=4), ]) def forward(self, x): diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 6405bea8..95e5b03a 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -22,14 +22,6 @@ class Wavegrad(nn.Module): assert len(upsample_factors) == len(upsample_dilations) assert len(upsample_factors) == len(ublock_out_channels) - # inference time noise schedule params - self.S = 1000 - beta, alpha, alpha_cum, noise_level = self._setup_noise_level() - self.register_buffer('beta', beta) - self.register_buffer('alpha', alpha) - self.register_buffer('alpha_cum', alpha_cum) - self.register_buffer('noise_level', noise_level) - # setup up-down sampling parameters self.hop_length = np.prod(upsample_factors) self.upsample_factors = upsample_factors @@ -68,21 +60,23 @@ class Wavegrad(nn.Module): # print(ic, 'last_conv--', out_channels) self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1) - def _setup_noise_level(self, noise_schedule=None): - """compute noise schedule parameters""" - if noise_schedule is None: - beta = np.linspace(1e-6, 0.01, self.S) - else: - beta = noise_schedule - alpha = 1 - beta - alpha_cum = np.cumprod(alpha) - noise_level = np.concatenate([[1.0], alpha_cum ** 0.5], axis=0) + # inference time noise schedule params + self.S = 1000 + self.init_noise_schedule(self.S) - beta = torch.from_numpy(beta) - alpha = torch.from_numpy(alpha) - alpha_cum = torch.from_numpy(alpha_cum) - noise_level = torch.from_numpy(noise_level.astype(np.float32)) - return beta, alpha, alpha_cum, noise_level + + def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01): + """compute noise schedule parameters""" + device = self.last_conv.weight.device + beta = torch.linspace(min_val, max_val, num_iter).to(device) + alpha = 1 - beta + alpha_cum = alpha.cumprod(dim=0) + noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5]) + + self.register_buffer('beta', beta) + self.register_buffer('alpha', alpha) + self.register_buffer('alpha_cum', alpha_cum) + self.register_buffer('noise_level', noise_level) def compute_noisy_x(self, x): B = x.shape[0] @@ -94,7 +88,7 @@ class Wavegrad(nn.Module): noise_scale = noise_scale.unsqueeze(1) noise = torch.randn_like(x) noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise - return noisy_x.unsqueeze(1), noise_scale[:, 0] + return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0] def forward(self, x, c, noise_scale): assert len(c.shape) == 3 # B, C, T @@ -114,9 +108,8 @@ class Wavegrad(nn.Module): def inference(self, c): with torch.no_grad(): - x = torch.randn(c.shape[0], self.hop_length * c.shape[-1]).to(c) - noise_scale = torch.from_numpy( - self.alpha_cum**0.5).float().unsqueeze(1).to(c) + x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c) + noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c) for n in range(len(self.alpha) - 1, -1, -1): c1 = 1 / self.alpha[n]**0.5 c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5