mirror of https://github.com/coqui-ai/TTS.git
wavegrad weight norm refactoring
parent 750a38f545
commit a44ef58aea
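Note: this commit moves weight normalization out of the layer constructors and behind a use_weight_norm config flag, adding paired apply_weight_norm() / remove_weight_norm() methods to FiLM, UBlock, DBlock, and the Wavegrad model. A minimal standalone sketch of the PyTorch mechanism being toggled (not part of the diff):

    import torch
    from torch import nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    conv = nn.Conv1d(32, 32, 3, padding=1)  # plain conv, as the refactored layers now construct
    conv = weight_norm(conv)                # reparameterizes weight into weight_g and weight_v
    y = conv(torch.rand(1, 32, 100))        # forward pass is unchanged
    remove_weight_norm(conv)                # folds weight_g / weight_v back into a single weight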
@@ -45,6 +45,7 @@
     // MODEL PARAMETERS
     "generator_model": "wavegrad",
     "model_params":{
+        "use_weight_norm": true,
         "y_conv_channels":32,
         "x_conv_channels":768,
         "ublock_out_channels": [512, 512, 256, 128, 128],
@@ -39,8 +39,8 @@ class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
-        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
+        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
+        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

         nn.init.xavier_uniform_(self.input_conv.weight)
         nn.init.xavier_uniform_(self.output_conv.weight)
@@ -48,12 +48,20 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.output_conv.bias)

     def forward(self, x, noise_scale):
-        x = self.input_conv(x)
-        x = F.leaky_relu(x, 0.2)
-        x = self.encoding(x, noise_scale)
-        shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
+        o = self.input_conv(x)
+        o = F.leaky_relu(o, 0.2)
+        o = self.encoding(o, noise_scale)
+        shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
         return shift, scale

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.input_conv)
+        nn.utils.remove_weight_norm(self.output_conv)
+
+    def apply_weight_norm(self):
+        self.input_conv = weight_norm(self.input_conv)
+        self.output_conv = weight_norm(self.output_conv)
+

 @torch.jit.script
 def shif_and_scale(x, scale, shift):
@@ -68,79 +76,100 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.block2 = nn.ModuleList([
-            weight_norm(Conv1d(input_size,
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[0]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1]))
+                   padding=dilation[1])
         ])
-        self.block3 = nn.ModuleList([
-            weight_norm(Conv1d(hidden_size,
+        self.out_block = nn.ModuleList([
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[2]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3]))
+                   padding=dilation[3])
         ])

     def forward(self, x, shift, scale):
-        o1 = F.interpolate(x, size=x.shape[-1] * self.factor)
-        o1 = self.block1(o1)
-
-        o2 = F.leaky_relu(x, 0.2)
-        o2 = F.interpolate(o2, size=x.shape[-1] * self.factor)
-        o2 = self.block2[0](o2)
-        o2 = shif_and_scale(o2, scale, shift)
-        o2 = F.leaky_relu(o2, 0.2)
-        o2 = self.block2[1](o2)
-
-        x = o1 + o2
-
-        o3 = shif_and_scale(x, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[0](o3)
-
-        o3 = shif_and_scale(o3, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[1](o3)
-
-        o = x + o3
+        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
+        res = self.res_block(x_inter)
+        o = F.leaky_relu(x_inter, 0.2)
+        o = F.interpolate(o, size=x.shape[-1] * self.factor)
+        o = self.main_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block[1](o)
+        res2 = res + o
+        o = shif_and_scale(res2, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[1](o)
+        o = o + res2
         return o

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+        for _, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
+        for idx, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                self.out_block[idx] = weight_norm(layer)
+

 class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.conv = nn.ModuleList([
-            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
         ])

     def forward(self, x):
         size = x.shape[-1] // self.factor
-        residual = self.residual_dense(x)
-        residual = F.interpolate(residual, size=size)
-
-        x = F.interpolate(x, size=size)
-        for layer in self.conv:
-            x = F.leaky_relu(x, 0.2)
-            x = layer(x)
-
-        return x + residual
+        res = self.res_block(x)
+        res = F.interpolate(res, size=size)
+        o = F.interpolate(x, size=size)
+        for layer in self.main_block:
+            o = F.leaky_relu(o, 0.2)
+            o = layer(o)
+        return o + res
+
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
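Note: the len(layer.state_dict()) != 0 guard in the new apply/remove methods skips parameter-free modules, which nn.utils.weight_norm cannot wrap (it expects a module with a weight parameter). The ModuleLists here hold only convolutions, so the check is defensive; a small standalone illustration of the pattern (not part of the diff):

    from torch import nn

    blocks = nn.ModuleList([nn.Conv1d(8, 8, 3, padding=1), nn.LeakyReLU(0.2)])
    for idx, layer in enumerate(blocks):
        if len(layer.state_dict()) != 0:  # LeakyReLU has no parameters, so it is skipped
            blocks[idx] = nn.utils.weight_norm(layer)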
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from torch import nn
+from torch.nn.utils import weight_norm

 from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d

@@ -10,6 +11,7 @@ class Wavegrad(nn.Module):
     def __init__(self,
                  in_channels=80,
                  out_channels=1,
+                 use_weight_norm=False,
                  y_conv_channels=32,
                  x_conv_channels=768,
                  dblock_out_channels=[128, 128, 256, 512],
@@ -19,6 +21,7 @@ class Wavegrad(nn.Module):
                  [1, 2, 4, 8], [1, 2, 4, 8]]):
         super().__init__()

+        self.use_weight_norm = use_weight_norm
         self.hop_len = np.prod(upsample_factors)
         self.noise_level = None
         self.num_steps = None
@@ -31,9 +34,8 @@ class Wavegrad(nn.Module):
         self.sigma = None

         # dblocks
-        self.dblocks = nn.ModuleList([
-            Conv1d(1, y_conv_channels, 5, padding=2),
-        ])
+        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
+        self.dblocks = nn.ModuleList([])
         ic = y_conv_channels
         for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
             self.dblocks.append(DBlock(ic, oc, df))
@@ -56,15 +58,22 @@ class Wavegrad(nn.Module):
         self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
         self.out_conv = Conv1d(oc, out_channels, 3, padding=1)

+        if use_weight_norm:
+            self.apply_weight_norm()
+
     def forward(self, x, spectrogram, noise_scale):
-        downsampled = []
-        for film, layer in zip(self.film, self.dblocks):
+        shift_and_scale = []
+
+        x = self.y_conv(x)
+        shift_and_scale.append(self.film[0](x, noise_scale))
+
+        for film, layer in zip(self.film[1:], self.dblocks):
             x = layer(x)
-            downsampled.append(film(x, noise_scale))
+            shift_and_scale.append(film(x, noise_scale))

         x = self.x_conv(spectrogram)
         for layer, (film_shift, film_scale) in zip(self.ublocks,
-                                                    reversed(downsampled)):
+                                                    reversed(shift_and_scale)):
             x = layer(x, film_shift, film_scale)
         x = self.out_conv(x)
         return x
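Note: the refactored forward runs the noisy waveform through y_conv before the DBlock stack, so film[0] now conditions on the raw y_conv features. A rough usage sketch (module path and hop length assumed from this repo's layout; shapes taken from the tests at the bottom of the diff; the noise_scale shape is an assumption):

    import torch
    from TTS.vocoder.models.wavegrad import Wavegrad  # assumed module path

    model = Wavegrad(in_channels=80, out_channels=1)  # assumes default upsample factors give hop_len == 300
    x = torch.rand(32, 1, 20 * 300)                   # noisy waveform
    mel = torch.rand(32, 80, 20)                      # conditioning spectrogram
    noise_scale = torch.rand(32)                      # per-sample noise level (assumed shape)
    o = model(x, mel, noise_scale)                    # -> (32, 1, 20 * 300)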
@@ -113,3 +122,48 @@ class Wavegrad(nn.Module):
         self.c1 = 1 / self.alpha**0.5
         self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
         self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5
+
+    def remove_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        nn.utils.remove_weight_norm(self.x_conv)
+        nn.utils.remove_weight_norm(self.out_conv)
+        nn.utils.remove_weight_norm(self.y_conv)
+
+    def apply_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        self.x_conv = weight_norm(self.x_conv)
+        self.out_conv = weight_norm(self.out_conv)
+        self.y_conv = weight_norm(self.y_conv)
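Note: the model-level remove_weight_norm() first tries the torch utility and falls back to the layer's own method, since DBlock, UBlock, and FiLM manage their child convs themselves while x_conv, out_conv, and y_conv are bare convolutions. A typical lifecycle sketch under these assumptions (not part of the diff):

    model = Wavegrad(use_weight_norm=True)  # constructor calls apply_weight_norm()
    # ... train, or load a checkpoint trained with weight norm ...
    model.remove_weight_norm()              # fold the reparameterization before export or inference
    model.eval()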
@@ -118,6 +118,7 @@ def setup_generator(c):
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
+            use_weight_norm=c['model_params']['use_weight_norm'],
             x_conv_channels=c['model_params']['x_conv_channels'],
             y_conv_channels=c['model_params']['y_conv_channels'],
             dblock_out_channels=c['model_params']['dblock_out_channels'],
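Note: setup_generator reads these keyword arguments from the 'model_params' section of the config; a minimal dict sketch mirroring the config hunk at the top of this diff (channel values copied from there, num_mels hypothetical, remaining keys omitted):

    c = {
        "audio": {"num_mels": 80},  # hypothetical value
        "generator_model": "wavegrad",
        "model_params": {
            "use_weight_norm": True,
            "y_conv_channels": 32,
            "x_conv_channels": 768,
            "dblock_out_channels": [128, 128, 256, 512],
            "ublock_out_channels": [512, 512, 256, 128, 128],
            # plus any remaining keys the full config defines
        },
    }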
@@ -32,6 +32,9 @@ def test_film():
     assert scale.shape[2] == 100
     assert isinstance(scale, torch.FloatTensor)

+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+

 def test_ublock():
     inp1 = torch.rand(32, 50, 100)
@@ -49,6 +52,9 @@ def test_ublock():
     assert o.shape[2] == 100
     assert isinstance(o, torch.FloatTensor)

+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+

 def test_dblock():
     inp = torch.rand(32, 50, 130)
@@ -60,6 +66,9 @@ def test_dblock():
     assert o.shape[2] == 65
     assert isinstance(o, torch.FloatTensor)

+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+

 def test_wavegrad_forward():
     x = torch.rand(32, 1, 20 * 300)
@@ -78,3 +87,6 @@ def test_wavegrad_forward():
     assert o.shape[1] == 1
     assert o.shape[2] == 20 * 300
     assert isinstance(o, torch.FloatTensor)
+
+    model.apply_weight_norm()
+    model.remove_weight_norm()