From b76a0be97a8c67df494d5767e75d211184cb2787 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 26 Oct 2020 16:47:18 +0100 Subject: [PATCH] wavegrad model and layers refactoring --- TTS/vocoder/models/wavegrad.py | 171 +++++++++++++---------------- TTS/vocoder/utils/generic_utils.py | 2 +- 2 files changed, 77 insertions(+), 96 deletions(-) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 95e5b03a..cbdb1205 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -2,7 +2,7 @@ import numpy as np import torch from torch import nn -from ..layers.wavegrad import DBlock, FiLM, UBlock +from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d class Wavegrad(nn.Module): @@ -10,8 +10,8 @@ class Wavegrad(nn.Module): def __init__(self, in_channels=80, out_channels=1, - x_conv_channels=32, - c_conv_channels=768, + y_conv_channels=32, + x_conv_channels=768, dblock_out_channels=[128, 128, 256, 512], ublock_out_channels=[512, 512, 256, 128, 128], upsample_factors=[5, 5, 3, 2, 2], @@ -19,106 +19,87 @@ class Wavegrad(nn.Module): [1, 2, 4, 8], [1, 2, 4, 8]]): super().__init__() - assert len(upsample_factors) == len(upsample_dilations) - assert len(upsample_factors) == len(ublock_out_channels) + self.hop_len = np.prod(upsample_factors) - # setup up-down sampling parameters - self.hop_length = np.prod(upsample_factors) - self.upsample_factors = upsample_factors - self.downsample_factors = upsample_factors[::-1][:-1] - - ### define DBlocks, FiLM layers ### + # dblocks self.dblocks = nn.ModuleList([ - nn.Conv1d(out_channels, x_conv_channels, 5, padding=2), + Conv1d(1, y_conv_channels, 5, padding=2), ]) - ic = x_conv_channels - self.films = nn.ModuleList([]) - for oc, df in zip(dblock_out_channels, self.downsample_factors): - # print('dblock(', ic, ', ', oc, ', ', df, ")") - layer = DBlock(ic, oc, df) - self.dblocks.append(layer) - - # print('film(', ic, ', ', oc,")") - layer = FiLM(ic, oc) - self.films.append(layer) + ic = y_conv_channels + for oc, df in zip(dblock_out_channels, reversed(upsample_factors)): + self.dblocks.append(DBlock(ic, oc, df)) ic = oc - # last FiLM block - # print('film(', ic, ', ', dblock_out_channels[-1],")") - self.films.append(FiLM(ic, dblock_out_channels[-1])) - ### define UBlocks ### - self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1) + # film + self.film = nn.ModuleList([]) + ic = y_conv_channels + for oc in reversed(ublock_out_channels): + self.film.append(FiLM(ic, oc)) + ic = oc + + # ublocks self.ublocks = nn.ModuleList([]) - ic = c_conv_channels - for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)): - # print('ublock(', ic, ', ', oc, ', ', uf, ")") - layer = UBlock(ic, oc, uf, upsample_dilations[idx]) - self.ublocks.append(layer) + ic = x_conv_channels + for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations): + self.ublocks.append(UBlock(ic, oc, uf, ud)) ic = oc - # define last layer - # print(ic, 'last_conv--', out_channels) - self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1) + self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1) + self.out_conv = Conv1d(oc, out_channels, 3, padding=1) - # inference time noise schedule params - self.S = 1000 - self.init_noise_schedule(self.S) + def forward(self, x, spectrogram, noise_scale): + downsampled = [] + for film, layer in zip(self.film, self.dblocks): + x = layer(x) + downsampled.append(film(x, noise_scale)) - - def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01): - """compute noise schedule parameters""" - device = self.last_conv.weight.device - beta = torch.linspace(min_val, max_val, num_iter).to(device) - alpha = 1 - beta - alpha_cum = alpha.cumprod(dim=0) - noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5]) - - self.register_buffer('beta', beta) - self.register_buffer('alpha', alpha) - self.register_buffer('alpha_cum', alpha_cum) - self.register_buffer('noise_level', noise_level) - - def compute_noisy_x(self, x): - B = x.shape[0] - if len(x.shape) == 3: - x = x.squeeze(1) - s = torch.randint(1, self.S + 1, [B]).to(x).long() - l_a, l_b = self.noise_level[s-1], self.noise_level[s] - noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a) - noise_scale = noise_scale.unsqueeze(1) - noise = torch.randn_like(x) - noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise - return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0] - - def forward(self, x, c, noise_scale): - assert len(c.shape) == 3 # B, C, T - assert len(x.shape) == 3 # B, 1, T - o = x - shift_and_scales = [] - for film, dblock in zip(self.films, self.dblocks): - o = dblock(o) - shift_and_scales.append(film(o, noise_scale)) - - o = self.c_conv(c) - for ublock, (film_shift, film_scale) in zip(self.ublocks, - reversed(shift_and_scales)): - o = ublock(o, film_shift, film_scale) - o = self.last_conv(o) - return o - - def inference(self, c): - with torch.no_grad(): - x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c) - noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c) - for n in range(len(self.alpha) - 1, -1, -1): - c1 = 1 / self.alpha[n]**0.5 - c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5 - x = c1 * (x - - c2 * self.forward(x, c, noise_scale[n]).squeeze(1)) - if n > 0: - noise = torch.randn_like(x) - sigma = ((1.0 - self.alpha_cum[n - 1]) / - (1.0 - self.alpha_cum[n]) * self.beta[n])**0.5 - x += sigma * noise - x = torch.clamp(x, -1.0, 1.0) + x = self.x_conv(spectrogram) + for layer, (film_shift, film_scale) in zip(self.ublocks, + reversed(downsampled)): + x = layer(x, film_shift, film_scale) + x = self.out_conv(x) return x + + @torch.no_grad() + def inference(self, x): + y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x) + sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x) + for n in range(len(self.alpha) - 1, -1, -1): + y_n = self.c1[n] * (y_n - + self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1)) + if n > 0: + z = torch.randn_like(y_n) + y_n += self.sigma[n - 1] * z + y_n.clamp_(-1.0, 1.0) + return y_n + + + def compute_y_n(self, y_0): + self.noise_level = self.noise_level.to(y_0) + if len(y_0.shape) == 3: + y_0 = y_0.squeeze(1) + s = torch.randint(1, self.num_steps + 1, [y_0.shape[0]]) + l_a, l_b = self.noise_level[s-1], self.noise_level[s] + noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) + noise_scale = noise_scale.unsqueeze(1) + noise = torch.randn_like(y_0) + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise + return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] + + def compute_noise_level(self, num_steps, min_val, max_val): + beta = np.linspace(min_val, max_val, num_steps) + alpha = 1 - beta + alpha_hat = np.cumprod(alpha) + noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0) + + self.num_steps = num_steps + self.beta = torch.tensor(beta.astype(np.float32)) + self.alpha = torch.tensor(alpha.astype(np.float32)) + self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) + self.noise_level = torch.tensor(noise_level.astype(np.float32)) + + self.c1 = 1 / self.alpha**0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5 + + diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index d0eb0657..761b14d7 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -119,7 +119,7 @@ def setup_generator(c): in_channels=c['audio']['num_mels'], out_channels=1, x_conv_channels=c['model_params']['x_conv_channels'], - c_conv_channels=c['model_params']['c_conv_channels'], + y_conv_channels=c['model_params']['y_conv_channels'], dblock_out_channels=c['model_params']['dblock_out_channels'], ublock_out_channels=c['model_params']['ublock_out_channels'], upsample_factors=c['model_params']['upsample_factors'],