From a44ef58aea7820aa3ac9ae064e8b45d15653d6a5 Mon Sep 17 00:00:00 2001
From: erogol
Date: Fri, 30 Oct 2020 13:23:24 +0100
Subject: [PATCH] wavegrad weight norm refactoring

---
 TTS/vocoder/configs/wavegrad_libritts.json |   1 +
 TTS/vocoder/layers/wavegrad.py             | 133 +++++++++++++--------
 TTS/vocoder/models/wavegrad.py             |  68 +++++++++--
 TTS/vocoder/utils/generic_utils.py         |   1 +
 tests/test_wavegrad_layers.py              |  12 ++
 5 files changed, 156 insertions(+), 59 deletions(-)

diff --git a/TTS/vocoder/configs/wavegrad_libritts.json b/TTS/vocoder/configs/wavegrad_libritts.json
index 5720a482..57c26709 100644
--- a/TTS/vocoder/configs/wavegrad_libritts.json
+++ b/TTS/vocoder/configs/wavegrad_libritts.json
@@ -45,6 +45,7 @@
     // MODEL PARAMETERS
     "generator_model": "wavegrad",
     "model_params":{
+        "use_weight_norm": true,
         "y_conv_channels":32,
         "x_conv_channels":768,
         "ublock_out_channels": [512, 512, 256, 128, 128],
diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py
index 2c781fd6..d09b4950 100644
--- a/TTS/vocoder/layers/wavegrad.py
+++ b/TTS/vocoder/layers/wavegrad.py
@@ -39,8 +39,8 @@ class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
-        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
+        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
+        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
         nn.init.xavier_uniform_(self.input_conv.weight)
         nn.init.xavier_uniform_(self.output_conv.weight)
@@ -48,12 +48,20 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.output_conv.bias)
 
     def forward(self, x, noise_scale):
-        x = self.input_conv(x)
-        x = F.leaky_relu(x, 0.2)
-        x = self.encoding(x, noise_scale)
-        shift, scale = torch.chunk(self.output_conv(x), 2, dim=1)
+        o = self.input_conv(x)
+        o = F.leaky_relu(o, 0.2)
+        o = self.encoding(o, noise_scale)
+        shift, scale = torch.chunk(self.output_conv(o), 2, dim=1)
         return shift, scale
 
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.input_conv)
+        nn.utils.remove_weight_norm(self.output_conv)
+
+    def apply_weight_norm(self):
+        self.input_conv = weight_norm(self.input_conv)
+        self.output_conv = weight_norm(self.output_conv)
+
 
 @torch.jit.script
 def shif_and_scale(x, scale, shift):
@@ -68,79 +76,100 @@ class UBlock(nn.Module):
         assert len(dilation) == 4
 
         self.factor = factor
-        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.block2 = nn.ModuleList([
-            weight_norm(Conv1d(input_size,
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[0]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1]))
+                   padding=dilation[1])
         ])
-        self.block3 = nn.ModuleList([
-            weight_norm(Conv1d(hidden_size,
+        self.out_block = nn.ModuleList([
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2])),
-            weight_norm(Conv1d(hidden_size,
+                   padding=dilation[2]),
+            Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3]))
+                   padding=dilation[3])
         ])
 
     def forward(self, x, shift, scale):
-        o1 = F.interpolate(x, size=x.shape[-1] * self.factor)
-        o1 = self.block1(o1)
-
-        o2 = F.leaky_relu(x, 0.2)
-        o2 = F.interpolate(o2, size=x.shape[-1] * self.factor)
-        o2 = self.block2[0](o2)
-        o2 = shif_and_scale(o2, scale, shift)
-        o2 = F.leaky_relu(o2, 0.2)
-        o2 = self.block2[1](o2)
-
-        x = o1 + o2
-
-        o3 = shif_and_scale(x, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[0](o3)
-
-        o3 = shif_and_scale(o3, scale, shift)
-        o3 = F.leaky_relu(o3, 0.2)
-        o3 = self.block3[1](o3)
-
-        o = x + o3
+        x_inter = F.interpolate(x, size=x.shape[-1] * self.factor)
+        res = self.res_block(x_inter)
+        o = F.leaky_relu(x_inter, 0.2)
+        o = F.interpolate(o, size=x.shape[-1] * self.factor)
+        o = self.main_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.main_block[1](o)
+        res2 = res + o
+        o = shif_and_scale(res2, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[0](o)
+        o = shif_and_scale(o, scale, shift)
+        o = F.leaky_relu(o, 0.2)
+        o = self.out_block[1](o)
+        o = o + res2
         return o
 
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+        for _, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
+        for idx, layer in enumerate(self.out_block):
+            if len(layer.state_dict()) != 0:
+                self.out_block[idx] = weight_norm(layer)
+
 
 class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
-        self.conv = nn.ModuleList([
-            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
-            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
+        self.res_block = Conv1d(input_size, hidden_size, 1)
+        self.main_block = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
         ])
 
     def forward(self, x):
         size = x.shape[-1] // self.factor
+        res = self.res_block(x)
+        res = F.interpolate(res, size=size)
+        o = F.interpolate(x, size=size)
+        for layer in self.main_block:
+            o = F.leaky_relu(o, 0.2)
+            o = layer(o)
+        return o + res
 
-        residual = self.residual_dense(x)
-        residual = F.interpolate(residual, size=size)
-
-        x = F.interpolate(x, size=size)
-        for layer in self.conv:
-            x = F.leaky_relu(x, 0.2)
-            x = layer(x)
-
-        return x + residual
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.res_block)
+        for _, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                nn.utils.remove_weight_norm(layer)
+
+    def apply_weight_norm(self):
+        self.res_block = weight_norm(self.res_block)
+        for idx, layer in enumerate(self.main_block):
+            if len(layer.state_dict()) != 0:
+                self.main_block[idx] = weight_norm(layer)
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index 9dc2193c..1130eb47 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from torch import nn
+from torch.nn.utils import weight_norm
 
 from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d
 
@@ -10,6 +11,7 @@ class Wavegrad(nn.Module):
     def __init__(self,
                  in_channels=80,
                  out_channels=1,
+                 use_weight_norm=False,
                 y_conv_channels=32,
                 x_conv_channels=768,
                 dblock_out_channels=[128, 128, 256, 512],
@@ -19,6 +21,7 @@ class Wavegrad(nn.Module):
                                     [1, 2, 4, 8],
                                     [1, 2, 4, 8]]):
         super().__init__()
+        self.use_weight_norm = use_weight_norm
         self.hop_len = np.prod(upsample_factors)
         self.noise_level = None
         self.num_steps = None
@@ -31,9 +34,8 @@ class Wavegrad(nn.Module):
         self.sigma = None
 
         # dblocks
-        self.dblocks = nn.ModuleList([
-            Conv1d(1, y_conv_channels, 5, padding=2),
-        ])
+        self.y_conv = Conv1d(1, y_conv_channels, 5, padding=2)
+        self.dblocks = nn.ModuleList([])
         ic = y_conv_channels
         for oc, df in zip(dblock_out_channels, reversed(upsample_factors)):
             self.dblocks.append(DBlock(ic, oc, df))
@@ -56,15 +58,22 @@ class Wavegrad(nn.Module):
         self.x_conv = Conv1d(in_channels, x_conv_channels, 3, padding=1)
         self.out_conv = Conv1d(oc, out_channels, 3, padding=1)
 
+        if use_weight_norm:
+            self.apply_weight_norm()
+
     def forward(self, x, spectrogram, noise_scale):
-        downsampled = []
-        for film, layer in zip(self.film, self.dblocks):
+        shift_and_scale = []
+
+        x = self.y_conv(x)
+        shift_and_scale.append(self.film[0](x, noise_scale))
+
+        for film, layer in zip(self.film[1:], self.dblocks):
             x = layer(x)
-            downsampled.append(film(x, noise_scale))
+            shift_and_scale.append(film(x, noise_scale))
 
         x = self.x_conv(spectrogram)
         for layer, (film_shift, film_scale) in zip(self.ublocks,
-                                                   reversed(downsampled)):
+                                                   reversed(shift_and_scale)):
             x = layer(x, film_shift, film_scale)
         x = self.out_conv(x)
         return x
@@ -113,3 +122,48 @@ class Wavegrad(nn.Module):
         self.c1 = 1 / self.alpha**0.5
         self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5
         self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5
+
+    def remove_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        nn.utils.remove_weight_norm(self.x_conv)
+        nn.utils.remove_weight_norm(self.out_conv)
+        nn.utils.remove_weight_norm(self.y_conv)
+
+    def apply_weight_norm(self):
+        for _, layer in enumerate(self.dblocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        for _, layer in enumerate(self.film):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+
+        for _, layer in enumerate(self.ublocks):
+            if len(layer.state_dict()) != 0:
+                layer.apply_weight_norm()
+
+        self.x_conv = weight_norm(self.x_conv)
+        self.out_conv = weight_norm(self.out_conv)
+        self.y_conv = weight_norm(self.y_conv)
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index 761b14d7..d6e2e13b 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -118,6 +118,7 @@ def setup_generator(c):
         model = MyModel(
             in_channels=c['audio']['num_mels'],
             out_channels=1,
+            use_weight_norm=c['model_params']['use_weight_norm'],
             x_conv_channels=c['model_params']['x_conv_channels'],
             y_conv_channels=c['model_params']['y_conv_channels'],
             dblock_out_channels=c['model_params']['dblock_out_channels'],
diff --git a/tests/test_wavegrad_layers.py b/tests/test_wavegrad_layers.py
index a1c6a7e5..d81ae47d 100644
--- a/tests/test_wavegrad_layers.py
+++ b/tests/test_wavegrad_layers.py
@@ -32,6 +32,9 @@ def test_film():
     assert scale.shape[2] == 100
     assert isinstance(scale, torch.FloatTensor)
 
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+
 
 def test_ublock():
     inp1 = torch.rand(32, 50, 100)
@@ -49,6 +52,9 @@ def test_ublock():
     assert o.shape[2] == 100
     assert isinstance(o, torch.FloatTensor)
 
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+
 
 def test_dblock():
     inp = torch.rand(32, 50, 130)
@@ -60,6 +66,9 @@ def test_dblock():
     assert o.shape[2] == 65
     assert isinstance(o, torch.FloatTensor)
 
+    layer.apply_weight_norm()
+    layer.remove_weight_norm()
+
 
 def test_wavegrad_forward():
     x = torch.rand(32, 1, 20 * 300)
@@ -78,3 +87,6 @@ def test_wavegrad_forward():
     assert o.shape[1] == 1
     assert o.shape[2] == 20 * 300
     assert isinstance(o, torch.FloatTensor)
+
+    model.apply_weight_norm()
+    model.remove_weight_norm()
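
Usage note for reviewers (not part of the patch): with this change, weight norm is applied once at construction when use_weight_norm is true, and remove_weight_norm() folds the weight_g/weight_v reparametrization back into a plain .weight before inference or export. Below is a minimal sketch of the intended round trip, assuming this branch is importable as the TTS package; the tensor shapes are borrowed from tests/test_wavegrad_layers.py.

    import torch
    from TTS.vocoder.models.wavegrad import Wavegrad

    # use_weight_norm=True makes __init__ call apply_weight_norm(), which wraps
    # every convolution in torch.nn.utils.weight_norm (weight -> weight_g, weight_v).
    model = Wavegrad(in_channels=80, out_channels=1, use_weight_norm=True)
    assert hasattr(model.y_conv, "weight_g")  # reparametrized

    spectrogram = torch.rand(1, 80, 20)                   # (batch, num_mels, frames)
    audio = torch.rand(1, 1, 20 * int(model.hop_len))     # hop_len = prod(upsample_factors)
    noise_scale = torch.rand(1)                           # one noise level per batch item

    with torch.no_grad():
        out = model(audio, spectrogram, noise_scale)
    assert out.shape == audio.shape

    # Fold weight norm back into .weight before export/inference.
    model.remove_weight_norm()
    assert not hasattr(model.y_conv, "weight_g")

The try/except ValueError in Wavegrad.remove_weight_norm is what lets a single loop handle both plain weight-normed convs and the composite DBlock/UBlock/FiLM children: nn.utils.remove_weight_norm raises ValueError on the composite modules, which then clean up through their own remove_weight_norm().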