wavegrad weight norm refactoring

erogol 2020-10-30 13:23:24 +01:00
parent 750a38f545
commit a44ef58aea
5 changed files with 156 additions and 59 deletions

View File

@@ -45,6 +45,7 @@
     // MODEL PARAMETERS
     "generator_model": "wavegrad",
     "model_params":{
+        "use_weight_norm": true,
         "y_conv_channels":32,
         "x_conv_channels":768,
         "ublock_out_channels": [512, 512, 256, 128, 128],

View File

@@ -39,8 +39,8 @@ class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
-        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
+        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
+        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
         nn.init.xavier_uniform_(self.input_conv.weight)
         nn.init.xavier_uniform_(self.output_conv.weight)
@@ -48,12 +48,20 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.output_conv.bias)

     def forward(self, x, noise_scale):
-        x = self.input_conv(x)
-        x = F.leaky_relu(x, 0.2)
-        x = self.encoding(x, noise_scale)
-        shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
+        o = self.input_conv(x)
+        o = F.leaky_relu(o, 0.2)
+        o = self.encoding(o, noise_scale)
+        shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
         return shift, scale

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.input_conv)
+        nn.utils.remove_weight_norm(self.output_conv)
+
+    def apply_weight_norm(self):
+        self.input_conv = weight_norm(self.input_conv)
+        self.output_conv = weight_norm(self.output_conv)
+

 @torch.jit.script
 def shif_and_scale(x, scale, shift):
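
Note: weight norm is no longer baked into FiLM at construction time; each layer now toggles it explicitly via apply_weight_norm() / remove_weight_norm(). A minimal sketch of what that round trip does to a conv layer (illustration only, not part of the diff):

# Illustration only (not part of the diff): the weight-norm round trip on a conv.
import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm

conv = nn.Conv1d(80, 80, 3, padding=1)
print([n for n, _ in conv.named_parameters()])  # ['weight', 'bias']

conv = weight_norm(conv)  # reparametrize: weight = weight_g * weight_v / ||weight_v||
print([n for n, _ in conv.named_parameters()])  # ['bias', 'weight_g', 'weight_v']

x = torch.rand(1, 80, 100)
y_before = conv(x)
remove_weight_norm(conv)  # fold weight_g / weight_v back into a plain weight
y_after = conv(x)
print(torch.allclose(y_before, y_after, atol=1e-6))  # True: same function, cheaper forward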
@@ -68,79 +76,100 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.block2 = nn.ModuleList([
-            weight_norm(Conv1d(input_size,
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[0]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1]))
+                   padding=dilation[1])
         ])
-        self.block3 = nn.ModuleList([
-            weight_norm(Conv1d(hidden_size,
+        self.out_block = nn.ModuleList([
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[2]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3]))
+                   padding=dilation[3])
         ])

     def forward(self, x, shift, scale):
-        o1 = F.interpolate(x, size=x.shape[-1] * self.factor)
-        o1 = self.block1(o1)
-        o2 = F.leaky_relu(x, 0.2)
-        o2 = F.interpolate(o2, size=x.shape[-1] * self.factor)
-        o2 = self.block2[0](o2)
-        o2 = shif_and_scale(o2, scale, shift)
-        o2 = F.leaky_relu(o2, 0.2)
-        o2 = self.block2[1](o2)
-        x = o1 + o2
-        o3 = shif_and_scale(x, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[0](o3)
-        o3 = shif_and_scale(o3, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[1](o3)
-        o = x + o3
+        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
+        res = self.res_block(x_inter)
+        o = F.leaky_relu(x_inter, 0.2)
+        o = F.interpolate(o, size=x.shape[-1] * self.factor)
+        o = self.main_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block[1](o)
+        res2 = res + o
+        o = shif_and_scale(res2, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[1](o)
+        o = o + res2
         return o

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+        for _, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
+        for idx, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                self.out_block[idx] = weight_norm(layer)
+

 class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.conv = nn.ModuleList([
-            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
         ])

     def forward(self, x):
         size = x.shape[-1] // self.factor
-        residual = self.residual_dense(x)
-        residual = F.interpolate(residual, size=size)
-        x = F.interpolate(x, size=size)
-        for layer in self.conv:
-            x = F.leaky_relu(x, 0.2)
-            x = layer(x)
-        return x + residual
+        res = self.res_block(x)
+        res = F.interpolate(res, size=size)
+        o = F.interpolate(x, size=size)
+        for layer in self.main_block:
+            o = F.leaky_relu(o, 0.2)
+            o = layer(o)
+        return o + res

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
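
Note: the renaming makes the two paths explicit: res_block is the 1x1 residual projection, while main_block / out_block hold the dilated conv stacks. UBlock upsamples time by `factor`, DBlock downsamples by it. A shape sketch (illustration only; the import path TTS.vocoder.layers.wavegrad is assumed from this repo's layout):

# Illustration only; assumes this repo's module path TTS.vocoder.layers.wavegrad.
import torch
from TTS.vocoder.layers.wavegrad import DBlock, UBlock

d = DBlock(input_size=32, hidden_size=128, factor=2)
u = UBlock(input_size=128, hidden_size=32, factor=2, dilation=[1, 2, 4, 8])

x = torch.rand(8, 32, 64)
h = d(x)                       # (8, 128, 32): time downsampled by `factor`
shift = torch.rand(8, 32, 64)  # FiLM conditioning: hidden_size channels
scale = torch.rand(8, 32, 64)  # at the upsampled length
o = u(h, shift, scale)         # (8, 32, 64): time upsampled back by `factor`
print(h.shape, o.shape)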

View File

@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from torch import nn
+from torch.nn.utils import weight_norm

 from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d
@@ -10,6 +11,7 @@ class Wavegrad(nn.Module):
     def __init__(self,
                  in_channels=80,
                  out_channels=1,
+                 use_weight_norm=False,
                  y_conv_channels=32,
                  x_conv_channels=768,
                  dblock_out_channels=[128, 128, 256, 512],
@@ -19,6 +21,7 @@ class Wavegrad(nn.Module):
                                    [1, 2, 4, 8], [1, 2, 4, 8]]):
         super().__init__()

+        self.use_weight_norm = use_weight_norm
         self.hop_len = np.prod(upsample_factors)
         self.noise_level = None
         self.num_steps = None
@@ -31,9 +34,8 @@ class Wavegrad(nn.Module):
         self.sigma = None

         # dblocks
-        self.dblocks = nn.ModuleList([
-            Conv1d(1, y_conv_channels, 5, padding=2),
-        ])
+        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
+        self.dblocks = nn.ModuleList([])
         ic = y_conv_channels
         for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
             self.dblocks.append(DBlock(ic, oc, df))
@@ -56,15 +58,22 @@ class Wavegrad(nn.Module):
         self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
         self.out_conv = Conv1d(oc, out_channels, 3, padding=1)

+        if use_weight_norm:
+            self.apply_weight_norm()
+
     def forward(self, x, spectrogram, noise_scale):
-        downsampled = []
-        for film, layer in zip(self.film, self.dblocks):
+        shift_and_scale = []
+
+        x = self.y_conv(x)
+        shift_and_scale.append(self.film[0](x, noise_scale))
+
+        for film, layer in zip(self.film[1:], self.dblocks):
             x = layer(x)
-            downsampled.append(film(x, noise_scale))
+            shift_and_scale.append(film(x, noise_scale))

         x = self.x_conv(spectrogram)
         for layer, (film_shift, film_scale) in zip(self.ublocks,
-                                                    reversed(downsampled)):
+                                                    reversed(shift_and_scale)):
             x = layer(x, film_shift, film_scale)
         x = self.out_conv(x)
         return x
@@ -113,3 +122,48 @@ class Wavegrad(nn.Module):
         self.c1 = 1 / self.alpha**0.5
         self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
         self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5
+
+    def remove_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        nn.utils.remove_weight_norm(self.x_conv)
+        nn.utils.remove_weight_norm(self.out_conv)
+        nn.utils.remove_weight_norm(self.y_conv)
+
+    def apply_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        self.x_conv = weight_norm(self.x_conv)
+        self.out_conv = weight_norm(self.out_conv)
+        self.y_conv = weight_norm(self.y_conv)
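
Note: with use_weight_norm=True the reparametrization (weight = g * v / ||v||) is applied to every conv at construction, and remove_weight_norm() folds g and v back into plain weights, so inference skips recomputing the norm on each forward. A usage sketch (illustration only; the import path TTS.vocoder.models.wavegrad is assumed from this repo's layout):

# Illustration only; assumes this repo's module path TTS.vocoder.models.wavegrad.
from TTS.vocoder.models.wavegrad import Wavegrad

model = Wavegrad(in_channels=80, out_channels=1, use_weight_norm=True)

# ... training runs with the g/v reparametrization in place ...

# Before inference or export, fold weight_g / weight_v back into plain weights;
# outputs are unchanged, but each conv skips recomputing g * v / ||v||.
model.remove_weight_norm()
model.eval()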

View File

@@ -118,6 +118,7 @@ def setup_generator(c):
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
+            use_weight_norm=c['model_params']['use_weight_norm'],
             x_conv_channels=c['model_params']['x_conv_channels'],
             y_conv_channels=c['model_params']['y_conv_channels'],
             dblock_out_channels=c['model_params']['dblock_out_channels'],
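
Note: for reference, the slice of config that this factory consumes, written as a Python dict (illustration only; values mirror the JSON hunk at the top of the diff):

# Illustration only: the config slice read by setup_generator for wavegrad.
c = {
    "audio": {"num_mels": 80},
    "generator_model": "wavegrad",
    "model_params": {
        "use_weight_norm": True,
        "y_conv_channels": 32,
        "x_conv_channels": 768,
        "ublock_out_channels": [512, 512, 256, 128, 128],
        # ... remaining wavegrad params (dblock_out_channels, upsample_factors, ...)
    },
}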

View File

@@ -32,6 +32,9 @@ def test_film():
     assert scale.shape[2] == 100
     assert isinstance(scale, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_ublock():
     inp1 = torch.rand(32, 50, 100)
@@ -49,6 +52,9 @@ def test_ublock():
     assert o.shape[2] == 100
     assert isinstance(o, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_dblock():
     inp = torch.rand(32, 50, 130)
@@ -60,6 +66,9 @@ def test_dblock():
     assert o.shape[2] == 65
     assert isinstance(o, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_wavegrad_forward():
     x = torch.rand(32, 1, 20 * 300)
@@ -78,3 +87,6 @@ def test_wavegrad_forward():
     assert o.shape[1] == 1
     assert o.shape[2] == 20 * 300
     assert isinstance(o, torch.FloatTensor)
+
+    model.apply_weight_norm()
+    model.remove_weight_norm()
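
Note: these tests only check that the toggle runs without raising. A stronger follow-up would assert functional equivalence across the round trip; a sketch (illustration only, not in the diff; DBlock sizes taken from test_dblock above):

# Illustration only: the apply/remove round trip should not change the function.
import torch
from TTS.vocoder.layers.wavegrad import DBlock

def test_dblock_weight_norm_roundtrip():
    layer = DBlock(input_size=50, hidden_size=100, factor=2)
    inp = torch.rand(32, 50, 130)
    with torch.no_grad():
        before = layer(inp)
        layer.apply_weight_norm()
        layer.remove_weight_norm()
        after = layer(inp)
    assert torch.allclose(before, after, atol=1e-5)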