wavegrad model and layers refactoring

This commit is contained in:
erogol 2020-10-26 16:47:18 +01:00
parent dc2825dfb2
commit b76a0be97a
2 changed files with 77 additions and 96 deletions

View File

@ -2,7 +2,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from ..layers.wavegrad import DBlock, FiLM, UBlock from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d
class Wavegrad(nn.Module): class Wavegrad(nn.Module):
@ -10,8 +10,8 @@ class Wavegrad(nn.Module):
def __init__(self, def __init__(self,
in_channels=80, in_channels=80,
out_channels=1, out_channels=1,
x_conv_channels=32, y_conv_channels=32,
c_conv_channels=768, x_conv_channels=768,
dblock_out_channels=[128, 128, 256, 512], dblock_out_channels=[128, 128, 256, 512],
ublock_out_channels=[512, 512, 256, 128, 128], ublock_out_channels=[512, 512, 256, 128, 128],
upsample_factors=[5, 5, 3, 2, 2], upsample_factors=[5, 5, 3, 2, 2],
@ -19,106 +19,87 @@ class Wavegrad(nn.Module):
[1, 2, 4, 8], [1, 2, 4, 8]]): [1, 2, 4, 8], [1, 2, 4, 8]]):
super().__init__() super().__init__()
assert len(upsample_factors) == len(upsample_dilations) self.hop_len = np.prod(upsample_factors)
assert len(upsample_factors) == len(ublock_out_channels)
# setup up-down sampling parameters # dblocks
self.hop_length = np.prod(upsample_factors)
self.upsample_factors = upsample_factors
self.downsample_factors = upsample_factors[::-1][:-1]
### define DBlocks, FiLM layers ###
self.dblocks = nn.ModuleList([ self.dblocks = nn.ModuleList([
nn.Conv1d(out_channels, x_conv_channels, 5, padding=2), Conv1d(1, y_conv_channels, 5, padding=2),
]) ])
ic = x_conv_channels ic = y_conv_channels
self.films = nn.ModuleList([]) for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
for oc, df in zip(dblock_out_channels, self.downsample_factors): self.dblocks.append(DBlock(ic, oc, df))
# print('dblock(', ic, ', ', oc, ', ', df, ")")
layer = DBlock(ic, oc, df)
self.dblocks.append(layer)
# print('film(', ic, ', ', oc,")")
layer = FiLM(ic, oc)
self.films.append(layer)
ic = oc ic = oc
# last FiLM block
# print('film(', ic, ', ', dblock_out_channels[-1],")")
self.films.append(FiLM(ic, dblock_out_channels[-1]))
### define UBlocks ### # film
self.c_conv = nn.Conv1d(in_channels, c_conv_channels, 3, padding=1) self.film = nn.ModuleList([])
ic = y_conv_channels
for oc in reversed(ublock_out_channels):
self.film.append(FiLM(ic, oc))
ic = oc
# ublocks
self.ublocks = nn.ModuleList([]) self.ublocks = nn.ModuleList([])
ic = c_conv_channels ic = x_conv_channels
for idx, (oc, uf) in enumerate(zip(ublock_out_channels, self.upsample_factors)): for oc, uf, ud in zip(ublock_out_channels, upsample_factors, upsample_dilations):
# print('ublock(', ic, ', ', oc, ', ', uf, ")") self.ublocks.append(UBlock(ic, oc, uf, ud))
layer = UBlock(ic, oc, uf, upsample_dilations[idx])
self.ublocks.append(layer)
ic = oc ic = oc
# define last layer self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
# print(ic, 'last_conv--', out_channels) self.out_conv = Conv1d(oc, out_channels, 3, padding=1)
self.last_conv = nn.Conv1d(ic, out_channels, 3, padding=1)
# inference time noise schedule params def forward(self, x, spectrogram, noise_scale):
self.S = 1000 downsampled = []
self.init_noise_schedule(self.S) for film, layer in zip(self.film, self.dblocks):
x = layer(x)
downsampled.append(film(x, noise_scale))
x = self.x_conv(spectrogram)
def init_noise_schedule(self, num_iter, min_val=1e-6, max_val=0.01): for layer, (film_shift, film_scale) in zip(self.ublocks,
"""compute noise schedule parameters""" reversed(downsampled)):
device = self.last_conv.weight.device x = layer(x, film_shift, film_scale)
beta = torch.linspace(min_val, max_val, num_iter).to(device) x = self.out_conv(x)
alpha = 1 - beta
alpha_cum = alpha.cumprod(dim=0)
noise_level = torch.cat([torch.FloatTensor([1]).to(device), alpha_cum ** 0.5])
self.register_buffer('beta', beta)
self.register_buffer('alpha', alpha)
self.register_buffer('alpha_cum', alpha_cum)
self.register_buffer('noise_level', noise_level)
def compute_noisy_x(self, x):
B = x.shape[0]
if len(x.shape) == 3:
x = x.squeeze(1)
s = torch.randint(1, self.S + 1, [B]).to(x).long()
l_a, l_b = self.noise_level[s-1], self.noise_level[s]
noise_scale = l_a + torch.rand(B).to(x) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(x)
noisy_x = noise_scale * x + (1.0 - noise_scale**2)**0.5 * noise
return noise.unsqueeze(1), noisy_x.unsqueeze(1), noise_scale[:, 0]
def forward(self, x, c, noise_scale):
assert len(c.shape) == 3 # B, C, T
assert len(x.shape) == 3 # B, 1, T
o = x
shift_and_scales = []
for film, dblock in zip(self.films, self.dblocks):
o = dblock(o)
shift_and_scales.append(film(o, noise_scale))
o = self.c_conv(c)
for ublock, (film_shift, film_scale) in zip(self.ublocks,
reversed(shift_and_scales)):
o = ublock(o, film_shift, film_scale)
o = self.last_conv(o)
return o
def inference(self, c):
with torch.no_grad():
x = torch.randn(c.shape[0], 1, self.hop_length * c.shape[-1]).to(c)
noise_scale = (self.alpha_cum**0.5).unsqueeze(1).to(c)
for n in range(len(self.alpha) - 1, -1, -1):
c1 = 1 / self.alpha[n]**0.5
c2 = (1 - self.alpha[n]) / (1 - self.alpha_cum[n])**0.5
x = c1 * (x -
c2 * self.forward(x, c, noise_scale[n]).squeeze(1))
if n > 0:
noise = torch.randn_like(x)
sigma = ((1.0 - self.alpha_cum[n - 1]) /
(1.0 - self.alpha_cum[n]) * self.beta[n])**0.5
x += sigma * noise
x = torch.clamp(x, -1.0, 1.0)
return x return x
@torch.no_grad()
def inference(self, x):
y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1], dtype=torch.float32).to(x)
sqrt_alpha_hat = self.noise_level.unsqueeze(1).to(x)
for n in range(len(self.alpha) - 1, -1, -1):
y_n = self.c1[n] * (y_n -
self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n]).squeeze(1))
if n > 0:
z = torch.randn_like(y_n)
y_n += self.sigma[n - 1] * z
y_n.clamp_(-1.0, 1.0)
return y_n
def compute_y_n(self, y_0):
self.noise_level = self.noise_level.to(y_0)
if len(y_0.shape) == 3:
y_0 = y_0.squeeze(1)
s = torch.randint(1, self.num_steps + 1, [y_0.shape[0]])
l_a, l_b = self.noise_level[s-1], self.noise_level[s]
noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
noise_scale = noise_scale.unsqueeze(1)
noise = torch.randn_like(y_0)
noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise
return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
def compute_noise_level(self, num_steps, min_val, max_val):
beta = np.linspace(min_val, max_val, num_steps)
alpha = 1 - beta
alpha_hat = np.cumprod(alpha)
noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
self.num_steps = num_steps
self.beta = torch.tensor(beta.astype(np.float32))
self.alpha = torch.tensor(alpha.astype(np.float32))
self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
self.noise_level = torch.tensor(noise_level.astype(np.float32))
self.c1 = 1 / self.alpha**0.5
self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5

View File

@ -119,7 +119,7 @@ def setup_generator(c):
in_channels=c['audio']['num_mels'], in_channels=c['audio']['num_mels'],
out_channels=1, out_channels=1,
x_conv_channels=c['model_params']['x_conv_channels'], x_conv_channels=c['model_params']['x_conv_channels'],
c_conv_channels=c['model_params']['c_conv_channels'], y_conv_channels=c['model_params']['y_conv_channels'],
dblock_out_channels=c['model_params']['dblock_out_channels'], dblock_out_channels=c['model_params']['dblock_out_channels'],
ublock_out_channels=c['model_params']['ublock_out_channels'], ublock_out_channels=c['model_params']['ublock_out_channels'],
upsample_factors=c['model_params']['upsample_factors'], upsample_factors=c['model_params']['upsample_factors'],