pwgan files

This commit is contained in:
erogol 2020-07-17 11:36:36 +02:00
parent 9bd415fbc9
commit 320bc29496
8 changed files with 641 additions and 0 deletions

View File

@ -0,0 +1,75 @@
import torch
from torch.nn import functional as F
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(self,
kernel_size=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
dilation=1,
bias=True,
use_causal_conv=False
):
super(ResidualBlock, self).__init__()
self.dropout = dropout
# no future time stamps available
if use_causal_conv:
padding = (kernel_size - 1) * dilation
else:
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
padding = (kernel_size - 1) // 2 * dilation
self.use_causal_conv = use_causal_conv
# dilation conv
self.conv = torch.nn.Conv1d(res_channels, gate_channels, kernel_size,
padding=padding, dilation=dilation, bias=bias)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
else:
self.conv1x1_aux = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)
def forward(self, x, c):
"""
x: B x D_res x T
c: B x D_aux x T
"""
residual = x
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.conv(x)
# remove future time steps if use_causal_conv conv
x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x
# split into two part for gated activation
splitdim = 1
xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
# local conditioning
if c is not None:
assert self.conv1x1_aux is not None
c = self.conv1x1_aux(c)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
xa, xb = xa + ca, xb + cb
x = torch.tanh(xa) * torch.sigmoid(xb)
# for skip connection
s = self.conv1x1_skip(x)
# for residual connection
x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)
return x, s

View File

@ -0,0 +1,100 @@
import numpy as np
import torch
from torch.nn import functional as F
class Stretch2d(torch.nn.Module):
def __init__(self, x_scale, y_scale, mode="nearest"):
super(Stretch2d, self).__init__()
self.x_scale = x_scale
self.y_scale = y_scale
self.mode = mode
def forward(self, x):
"""
x (Tensor): Input tensor (B, C, F, T).
Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale),
"""
return F.interpolate(
x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)
class UpsampleNetwork(torch.nn.Module):
def __init__(self,
upsample_factors,
nonlinear_activation=None,
nonlinear_activation_params={},
interpolate_mode="nearest",
freq_axis_kernel_size=1,
use_causal_conv=False,
):
super(UpsampleNetwork, self).__init__()
self.use_causal_conv = use_causal_conv
self.up_layers = torch.nn.ModuleList()
for scale in upsample_factors:
# interpolation layer
stretch = Stretch2d(scale, 1, interpolate_mode)
self.up_layers += [stretch]
# conv layer
assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
freq_axis_padding = (freq_axis_kernel_size - 1) // 2
kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
if use_causal_conv:
padding = (freq_axis_padding, scale * 2)
else:
padding = (freq_axis_padding, scale)
conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
self.up_layers += [conv]
# nonlinear
if nonlinear_activation is not None:
nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
self.up_layers += [nonlinear]
def forward(self, c):
"""
c : (B, C, T_in).
Tensor: (B, C, T_upsample)
"""
c = c.unsqueeze(1) # (B, 1, C, T)
for f in self.up_layers:
c = f(c)
return c.squeeze(1) # (B, C, T')
class ConvUpsample(torch.nn.Module):
def __init__(self,
upsample_factors,
nonlinear_activation=None,
nonlinear_activation_params={},
interpolate_mode="nearest",
freq_axis_kernel_size=1,
aux_channels=80,
aux_context_window=0,
use_causal_conv=False
):
super(ConvUpsample, self).__init__()
self.aux_context_window = aux_context_window
self.use_causal_conv = use_causal_conv and aux_context_window > 0
# To capture wide-context information in conditional features
kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
# NOTE(kan-bayashi): Here do not use padding because the input is already padded
self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
self.upsample = UpsampleNetwork(
upsample_factors=upsample_factors,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
interpolate_mode=interpolate_mode,
freq_axis_kernel_size=freq_axis_kernel_size,
use_causal_conv=use_causal_conv,
)
def forward(self, c):
"""
c : (B, C, T_in).
Tensor: (B, C, T_upsampled),
"""
c_ = self.conv_in(c)
c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
return self.upsample(c)

View File

@ -0,0 +1,192 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
class ParallelWaveganDiscriminator(nn.Module):
"""PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
It classifies each audio window real/fake and returns a sequence
of predictions.
It is a stack of convolutional blocks with dilation.
"""
def __init__(self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=10,
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True,
):
super(ParallelWaveganDiscriminator, self).__init__()
assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size."
assert dilation_factor > 0, " [!] dilation factor must be > 0."
self.conv_layers = nn.ModuleList()
conv_in_channels = in_channels
for i in range(num_layers - 1):
if i == 0:
dilation = 1
else:
dilation = i if dilation_factor == 1 else dilation_factor ** i
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [
nn.Conv1d(conv_in_channels, conv_channels,
kernel_size=kernel_size, padding=padding,
dilation=dilation, bias=bias),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
]
self.conv_layers += conv_layer
padding = (kernel_size - 1) // 2
last_conv_layer = nn.Conv1d(
conv_in_channels, out_channels,
kernel_size=kernel_size, padding=padding, bias=bias)
self.conv_layers += [last_conv_layer]
self.apply_weight_norm()
def forward(self, x):
"""
x : (B, 1, T).
Returns:
Tensor: (B, 1, T)
"""
for f in self.conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
class ResidualParallelWaveganDiscriminator(nn.Module):
def __init__(self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
):
super(ResidualParallelWaveganDiscriminator, self).__init__()
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
self.in_channels = in_channels
self.out_channels = out_channels
self.num_layers = num_layers
self.stacks = stacks
self.kernel_size = kernel_size
self.res_factor = math.sqrt(1.0 / num_layers)
# check the number of num_layers and stacks
assert num_layers % stacks == 0
layers_per_stack = num_layers // stacks
# define first convolution
self.first_conv = nn.Sequential(
nn.Conv1d(in_channels,
res_channels,
kernel_size=1,
padding=0,
dilation=1,
bias=True),
getattr(nn, nonlinear_activation)(inplace=True,
**nonlinear_activation_params),
)
# define residual blocks
self.conv_layers = nn.ModuleList()
for layer in range(num_layers):
dilation = 2 ** (layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=-1,
dilation=dilation,
dropout=dropout,
bias=bias,
use_causal_conv=False,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = nn.ModuleList([
getattr(nn, nonlinear_activation)(inplace=True,
**nonlinear_activation_params),
nn.Conv1d(skip_channels,
skip_channels,
kernel_size=1,
padding=0,
dilation=1,
bias=True),
getattr(nn, nonlinear_activation)(inplace=True,
**nonlinear_activation_params),
nn.Conv1d(skip_channels,
out_channels,
kernel_size=1,
padding=0,
dilation=1,
bias=True),
])
# apply weight norm
self.apply_weight_norm()
def forward(self, x):
"""
x: (B, 1, T).
"""
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, None)
skips += h
skips *= self.res_factor
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
print(f"Weight norm is removed from {m}.")
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)

View File

@ -0,0 +1,162 @@
import math
import numpy as np
import torch
from torch.nn.utils import weight_norm
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.
It is conditioned on an aux feature (spectrogram) to generate
an output waveform from an input noise.
"""
def __init__(self,
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
aux_context_window=2,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_net="ConvInUpsampleNetwork",
upsample_factors=[4, 4, 4, 4],
inference_padding=2):
super(ParallelWaveganGenerator, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.num_res_blocks = num_res_blocks
self.stacks = stacks
self.kernel_size = kernel_size
self.upsample_factors = upsample_factors
self.upsample_scale = np.prod(upsample_factors)
self.inference_padding = inference_padding
# check the number of layers and stacks
assert num_res_blocks % stacks == 0
layers_per_stack = num_res_blocks // stacks
# define first convolution
self.first_conv = torch.nn.Conv1d(in_channels,
res_channels,
kernel_size=1,
bias=True)
# define conv + upsampling network
self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)
# define residual blocks
self.conv_layers = torch.nn.ModuleList()
for layer in range(num_res_blocks):
dilation = 2**(layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
res_channels=res_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias,
)
self.conv_layers += [conv]
# define output layers
self.last_conv_layers = torch.nn.ModuleList([
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels,
skip_channels,
kernel_size=1,
bias=True),
torch.nn.ReLU(inplace=True),
torch.nn.Conv1d(skip_channels,
out_channels,
kernel_size=1,
bias=True),
])
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(self, c):
"""
c: (B, C ,T').
o: Output tensor (B, out_channels, T)
"""
# random noise
x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
x = x.to(self.first_conv.bias.device)
# perform upsampling
if c is not None and self.upsample_net is not None:
c = self.upsample_net(c)
assert c.shape[-1] == x.shape[
-1], f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"
# encode to hidden representation
x = self.first_conv(x)
skips = 0
for f in self.conv_layers:
x, h = f(x, c)
skips += h
skips *= math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
x = skips
for f in self.last_conv_layers:
x = f(x)
return x
def inference(self, c):
c = c.to(self.first_conv.weight.device)
c = torch.nn.functional.pad(
c, (self.inference_padding, self.inference_padding), 'replicate')
return self.forward(c)
def remove_weight_norm(self):
def _remove_weight_norm(m):
try:
# print(f"Weight norm is removed from {m}.")
torch.nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(_remove_weight_norm)
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.Conv2d):
torch.nn.utils.weight_norm(m)
# print(f"Weight norm is applied to {m}.")
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers,
stacks,
kernel_size,
dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks,
self.kernel_size)

View File

@ -0,0 +1,41 @@
import numpy as np
import torch
from TTS.vocoder.models.parallel_wavegan_discriminator import ParallelWaveganDiscriminator, ResidualParallelWaveganDiscriminator
def test_pwgan_disciminator():
model = ParallelWaveganDiscriminator(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=10,
conv_channels=64,
dilation_factor=1,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2},
bias=True)
dummy_x = torch.rand((4, 1, 64 * 256))
output = model(dummy_x)
assert np.all(output.shape == (4, 1, 64 * 256))
model.remove_weight_norm()
def test_redisual_pwgan_disciminator():
model = ResidualParallelWaveganDiscriminator(
in_channels=1,
out_channels=1,
kernel_size=3,
num_layers=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
dropout=0.0,
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.2})
dummy_x = torch.rand((4, 1, 64 * 256))
output = model(dummy_x)
assert np.all(output.shape == (4, 1, 64 * 256))
model.remove_weight_norm()

View File

@ -0,0 +1,30 @@
import numpy as np
import torch
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator
def test_pwgan_generator():
model = ParallelWaveganGenerator(
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=30,
stacks=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
aux_context_window=2,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_factors=[4, 4, 4, 4])
dummy_c = torch.rand((4, 80, 64))
output = model(dummy_c)
assert np.all(output.shape == (4, 1, 64 * 256))
model.remove_weight_norm()
output = model.inference(dummy_c)
assert np.all(output.shape == (4, 1, (64 + 4) * 256))

View File

@ -0,0 +1,12 @@
import numpy as np
import tensorflow as tf
from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
def test_melgan_generator():
hop_length = 256
model = MelganGenerator()
dummy_input = tf.random.uniform((4, 80, 64))
output = model(dummy_input, training=False)
assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape

View File

@ -0,0 +1,29 @@
import os
import tensorflow as tf
import soundfile as sf
from librosa.core import load
from TTS.tests import get_tests_path, get_tests_input_path
from TTS.vocoder.tf.layers.pqmf import PQMF
TESTS_PATH = get_tests_path()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
def test_pqmf():
w, sr = load(WAV_FILE)
layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
w, sr = load(WAV_FILE)
w2 = tf.convert_to_tensor(w[None, None, :])
b2 = layer.analysis(w2)
w2_ = layer.synthesis(b2)
w2_ = w2.numpy()
print(w2_.max())
print(w2_.min())
print(w2_.mean())
sf.write('tf_pqmf_output.wav', w2_.flatten(), sr)