mirror of https://github.com/coqui-ai/TTS.git
wavegrad weight norm refactoring
This commit is contained in:
parent 750a38f545
commit a44ef58aea
@@ -45,6 +45,7 @@
     // MODEL PARAMETERS
     "generator_model": "wavegrad",
     "model_params":{
+        "use_weight_norm": true,
        "y_conv_channels":32,
        "x_conv_channels":768,
        "ublock_out_channels": [512, 512, 256, 128, 128],
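Note: setup_generator (further down in this commit) reads the new key directly from the parsed config, so wavegrad configs need to carry it. A minimal sketch of the lookup, with an illustrative stand-in dict rather than the full config:

    # illustrative stand-in for the parsed JSON config above
    c = {"model_params": {"use_weight_norm": True}}

    # setup_generator indexes the flag verbatim; an older config without
    # the key would raise a KeyError here
    use_weight_norm = c["model_params"]["use_weight_norm"]
    assert use_weight_norm is True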
@@ -39,8 +39,8 @@ class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
-        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
+        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
+        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)

         nn.init.xavier_uniform_(self.input_conv.weight)
         nn.init.xavier_uniform_(self.output_conv.weight)
@@ -48,12 +48,20 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.output_conv.bias)

     def forward(self, x, noise_scale):
-        x = self.input_conv(x)
-        x = F.leaky_relu(x, 0.2)
-        x = self.encoding(x, noise_scale)
-        shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
+        o = self.input_conv(x)
+        o = F.leaky_relu(o, 0.2)
+        o = self.encoding(o, noise_scale)
+        shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
         return shift, scale

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.input_conv)
+        nn.utils.remove_weight_norm(self.output_conv)
+
+    def apply_weight_norm(self):
+        self.input_conv = weight_norm(self.input_conv)
+        self.output_conv = weight_norm(self.output_conv)
+

 @torch.jit.script
 def shif_and_scale(x, scale, shift):
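Note: a minimal standalone sketch (plain PyTorch, sizes illustrative) of the apply/remove round-trip that FiLM now exposes. weight_norm reparametrizes .weight as g * v / ||v||; removing it folds the product back into a plain .weight, so the module computes the same function:

    import torch
    from torch import nn
    from torch.nn.utils import weight_norm

    conv = weight_norm(nn.Conv1d(8, 8, 3, padding=1))  # registers weight_g / weight_v
    x = torch.randn(1, 8, 100)
    y_before = conv(x)

    nn.utils.remove_weight_norm(conv)  # folds the reparametrization into .weight
    y_after = conv(x)
    assert torch.allclose(y_before, y_after, atol=1e-6)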
@@ -68,79 +76,100 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.block2 = nn.ModuleList([
-            weight_norm(Conv1d(input_size,
-                               hidden_size,
-                               3,
-                               dilation=dilation[0],
-                               padding=dilation[0])),
-            weight_norm(Conv1d(hidden_size,
-                               hidden_size,
-                               3,
-                               dilation=dilation[1],
-                               padding=dilation[1]))
-        ])
-        self.block3 = nn.ModuleList([
-            weight_norm(Conv1d(hidden_size,
-                               hidden_size,
-                               3,
-                               dilation=dilation[2],
-                               padding=dilation[2])),
-            weight_norm(Conv1d(hidden_size,
-                               hidden_size,
-                               3,
-                               dilation=dilation[3],
-                               padding=dilation[3]))
-        ])
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size,
+                   hidden_size,
+                   3,
+                   dilation=dilation[0],
+                   padding=dilation[0]),
+            Conv1d(hidden_size,
+                   hidden_size,
+                   3,
+                   dilation=dilation[1],
+                   padding=dilation[1])
+        ])
+        self.out_block = nn.ModuleList([
+            Conv1d(hidden_size,
+                   hidden_size,
+                   3,
+                   dilation=dilation[2],
+                   padding=dilation[2]),
+            Conv1d(hidden_size,
+                   hidden_size,
+                   3,
+                   dilation=dilation[3],
+                   padding=dilation[3])
+        ])

     def forward(self, x, shift, scale):
-        o1 = F.interpolate(x, size=x.shape[-1] * self.factor)
-        o1 = self.block1(o1)
-        o2 = F.leaky_relu(x, 0.2)
-        o2 = F.interpolate(o2, size=x.shape[-1] * self.factor)
-        o2 = self.block2[0](o2)
-        o2 = shif_and_scale(o2, scale, shift)
-        o2 = F.leaky_relu(o2, 0.2)
-        o2 = self.block2[1](o2)
-        x = o1 + o2
-        o3 = shif_and_scale(x, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[0](o3)
-        o3 = shif_and_scale(o3, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[1](o3)
-        o = x + o3
+        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
+        res = self.res_block(x_inter)
+        o = F.leaky_relu(x_inter, 0.2)
+        o = F.interpolate(o, size=x.shape[-1] * self.factor)
+        o = self.main_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block[1](o)
+        res2 = res + o
+        o = shif_and_scale(res2, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[1](o)
+        o = o + res2
         return o

+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+        for _, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
+        for idx, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                self.out_block[idx] = weight_norm(layer)
+

 class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.conv = nn.ModuleList([
-            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
         ])

     def forward(self, x):
         size = x.shape[-1] // self.factor
-        residual = self.residual_dense(x)
-        residual = F.interpolate(residual, size=size)
-        x = F.interpolate(x, size=size)
-        for layer in self.conv:
-            x = F.leaky_relu(x, 0.2)
-            x = layer(x)
-        return x + residual
+        res = self.res_block(x)
+        res = F.interpolate(res, size=size)
+        o = F.interpolate(x, size=size)
+        for layer in self.main_block:
+            o = F.leaky_relu(o, 0.2)
+            o = layer(o)
+        return o + res
+
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
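Note: the len(layer.state_dict()) != 0 guard in the new helpers skips parameter-free modules, which weight_norm cannot wrap. A quick illustrative check of what the guard distinguishes:

    from torch import nn

    blocks = nn.ModuleList([nn.Conv1d(4, 4, 3, padding=1), nn.ReLU()])
    for layer in blocks:
        # Conv1d carries parameters -> True; ReLU has an empty state_dict -> False
        print(type(layer).__name__, len(layer.state_dict()) != 0)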
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from torch import nn
+from torch.nn.utils import weight_norm

 from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d

@@ -10,6 +11,7 @@ class Wavegrad(nn.Module):
     def __init__(self,
                  in_channels=80,
                  out_channels=1,
+                 use_weight_norm=False,
                  y_conv_channels=32,
                  x_conv_channels=768,
                  dblock_out_channels=[128, 128, 256, 512],
@@ -19,6 +21,7 @@ class Wavegrad(nn.Module):
                  [1, 2, 4, 8], [1, 2, 4, 8]]):
         super().__init__()

+        self.use_weight_norm = use_weight_norm
         self.hop_len = np.prod(upsample_factors)
         self.noise_level = None
         self.num_steps = None
@@ -31,9 +34,8 @@ class Wavegrad(nn.Module):
         self.sigma = None

         # dblocks
-        self.dblocks = nn.ModuleList([
-            Conv1d(1, y_conv_channels, 5, padding=2),
-        ])
+        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
+        self.dblocks = nn.ModuleList([])
         ic = y_conv_channels
         for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
             self.dblocks.append(DBlock(ic, oc, df))
@@ -56,15 +58,22 @@ class Wavegrad(nn.Module):
         self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
         self.out_conv = Conv1d(oc, out_channels, 3, padding=1)

+        if use_weight_norm:
+            self.apply_weight_norm()
+
     def forward(self, x, spectrogram, noise_scale):
-        downsampled = []
-        for film, layer in zip(self.film, self.dblocks):
+        shift_and_scale = []
+
+        x = self.y_conv(x)
+        shift_and_scale.append(self.film[0](x, noise_scale))
+
+        for film, layer in zip(self.film[1:], self.dblocks):
             x = layer(x)
-            downsampled.append(film(x, noise_scale))
+            shift_and_scale.append(film(x, noise_scale))

         x = self.x_conv(spectrogram)
         for layer, (film_shift, film_scale) in zip(self.ublocks,
-                                                    reversed(downsampled)):
+                                                    reversed(shift_and_scale)):
             x = layer(x, film_shift, film_scale)
         x = self.out_conv(x)
         return x
@@ -113,3 +122,48 @@ class Wavegrad(nn.Module):
         self.c1 = 1 / self.alpha**0.5
         self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
         self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5
+
+    def remove_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        nn.utils.remove_weight_norm(self.x_conv)
+        nn.utils.remove_weight_norm(self.out_conv)
+        nn.utils.remove_weight_norm(self.y_conv)
+
+    def apply_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        self.x_conv = weight_norm(self.x_conv)
+        self.out_conv = weight_norm(self.out_conv)
+        self.y_conv = weight_norm(self.y_conv)
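Note on the try/except ValueError pattern above: nn.utils.remove_weight_norm only succeeds on a module that carries the weight-norm hook itself; called on a composite block (DBlock, UBlock, FiLM) whose hooks live on its children, it raises ValueError, and the model then falls back to the block's own remove_weight_norm. An illustrative sketch of that failure mode, using nn.Sequential as a stand-in for such a block:

    from torch import nn
    from torch.nn.utils import weight_norm

    inner = weight_norm(nn.Conv1d(4, 4, 3, padding=1))
    wrapper = nn.Sequential(inner)  # stand-in for a composite block

    try:
        nn.utils.remove_weight_norm(wrapper)  # no hook on the wrapper itself
    except ValueError:
        nn.utils.remove_weight_norm(wrapper[0])  # delegate to the child, as the model does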
@@ -118,6 +118,7 @@ def setup_generator(c):
     model = MyModel(
         in_channels=c['audio']['num_mels'],
         out_channels=1,
+        use_weight_norm=c['model_params']['use_weight_norm'],
         x_conv_channels=c['model_params']['x_conv_channels'],
         y_conv_channels=c['model_params']['y_conv_channels'],
         dblock_out_channels=c['model_params']['dblock_out_channels'],
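Note: the workflow this wiring enables is the usual vocoder pattern of training with weight norm on and folding it away before inference. A hedged sketch; the import path is assumed from this repo's layout, and the constructor defaults come from the diff above:

    from TTS.vocoder.models.wavegrad import Wavegrad  # path assumed

    model = Wavegrad(use_weight_norm=True)  # norms applied at construction
    model.remove_weight_norm()              # fold them into plain weights for inference
    model.eval()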
@@ -32,6 +32,9 @@ def test_film():
     assert scale.shape[2] == 100
     assert isinstance(scale, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_ublock():
     inp1 = torch.rand(32, 50, 100)
@@ -49,6 +52,9 @@ def test_ublock():
     assert o.shape[2] == 100
     assert isinstance(o, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_dblock():
     inp = torch.rand(32, 50, 130)
@@ -60,6 +66,9 @@ def test_dblock():
     assert o.shape[2] == 65
     assert isinstance(o, torch.FloatTensor)
+
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()


 def test_wavegrad_forward():
     x = torch.rand(32, 1, 20 * 300)
@@ -78,3 +87,6 @@ def test_wavegrad_forward():
     assert o.shape[1] == 1
     assert o.shape[2] == 20 * 300
     assert isinstance(o, torch.FloatTensor)
+
+    model.apply_weight_norm()
+    model.remove_weight_norm()