hifigan implementation update

This commit is contained in:
Eren Gölge 2021-04-05 11:32:51 +02:00
parent a14d7bc5db
commit 8c9e1c9e58
6 changed files with 235 additions and 126 deletions

View File

@ -2,26 +2,32 @@ from torch import nn
class ResStack(nn.Module):
def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]):
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
super(ResStack, self).__init__()
resstack = []
for dilation in dilations:
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(
nn.Conv1d(channel,
channel,
kernel_size=kernel,
dilation=dilation)),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.weight_norm(nn.Conv1d(channel, channel,
kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
self.shortcut = nn.utils.weight_norm(
nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
x1 = self.shortcut(x)
x2 = self.resstack(x)
return x1 + x2
x1 = self.shortcut(x)
x2 = self.resstack(x)
return x1 + x2
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.shortcut)
@ -32,18 +38,19 @@ class ResStack(nn.Module):
nn.utils.remove_weight_norm(self.resstack[14])
nn.utils.remove_weight_norm(self.resstack[17])
class MRF(nn.Module):
def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]):
super(MRF, self).__init__()
self.resblock1 = ResStack(kernels[0], channel, 0)
self.resblock2 = ResStack(kernels[1], channel, 6)
self.resblock3 = ResStack(kernels[2], channel, 12)
def __init__(self, kernels, channel, dilations=[1, 3, 5]):
super().__init__()
self.resblock1 = ResStack(kernels[0], channel, 0, dilations)
self.resblock2 = ResStack(kernels[1], channel, 6, dilations)
self.resblock3 = ResStack(kernels[2], channel, 12, dilations)
def forward(self, x):
x1 = self.resblock1(x)
x2 = self.resblock2(x)
x3 = self.resblock3(x)
return x1 + x2 + x3
x1 = self.resblock1(x)
x2 = self.resblock2(x)
x3 = self.resblock3(x)
return x1 + x2 + x3
def remove_weight_norm(self):
self.resblock1.remove_weight_norm()

View File

@ -1,71 +1,174 @@
import torch
from torch import nn
from TTS.vocoder.layers.hifigan import MRF
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
LRELU_SLOPE = 0.1
class HifiganGenerator(nn.Module):
def get_padding(k, d):
return int((k * d - d) / 2)
def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
super(HifiganGenerator, self).__init__()
self.inference_padding = 2
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super().__init__()
self.convs1 = nn.ModuleList([
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.input = nn.Sequential(
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
)
generator = []
for k in upsample_kernel:
inp = base_channels
out = int(inp / 2)
generator += [
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
MRF(resblock_kernel_sizes, out, resblock_dilation_sizes)
]
base_channels = out
self.generator = nn.Sequential(*generator)
self.output = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)),
nn.Tanh()
)
self.convs2 = nn.ModuleList([
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1)))
])
def forward(self, x):
x1 = self.input(x)
x2 = self.generator(x1)
out = self.output(x2)
return out
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super().__init__()
self.convs = nn.ModuleList([
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(
Conv1d(channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1])))
])
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class HifiganGenerator(torch.nn.Module):
def __init__(self, in_channels, out_channels, resblock_type, resblock_dilation_sizes,
resblock_kernel_sizes, upsample_kernel_sizes,
upsample_initial_channel, upsample_factors):
super().__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_factors)
self.conv_pre = weight_norm(
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if resblock_type == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_factors,
upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(upsample_initial_channel // (2**i),
upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def inference(self, c):
c = c.to(self.layers[1].weight.device)
c = c.to(self.conv_pre.weight.device)
c = torch.nn.functional.pad(
c,
(self.inference_padding, self.inference_padding),
'replicate')
c, (self.inference_padding, self.inference_padding), 'replicate')
return self.forward(c)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input[1])
nn.utils.remove_weight_norm(self.output[2])
for idx, layer in enumerate(self.generator):
if len(layer.state_dict()) != 0:
try:
nn.utils.remove_weight_norm(layer)
except:
layer.remove_weight_norm()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)

View File

@ -10,7 +10,9 @@ class MelganDiscriminator(nn.Module):
kernel_sizes=(5, 3),
base_channels=16,
max_channels=1024,
downsample_factors=(4, 4, 4, 4)):
downsample_factors=(4, 4, 4, 4),
groups_denominator=4,
max_groups=256):
super(MelganDiscriminator, self).__init__()
self.layers = nn.ModuleList()
@ -35,7 +37,7 @@ class MelganDiscriminator(nn.Module):
max_channels)
layer_kernel_size = downsample_factor * 10 + 1
layer_padding = (layer_kernel_size - 1) // 2
layer_groups = layer_in_channels // 4
layer_groups = layer_in_channels // groups_denominator
self.layers += [
nn.Sequential(
weight_norm(

View File

@ -14,7 +14,9 @@ class MelganMultiscaleDiscriminator(nn.Module):
downsample_factors=(4, 4, 4),
pooling_kernel_size=4,
pooling_stride=2,
pooling_padding=1):
pooling_padding=2,
groups_denominator=4,
max_groups=256):
super(MelganMultiscaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
@ -23,12 +25,16 @@ class MelganMultiscaleDiscriminator(nn.Module):
kernel_sizes=kernel_sizes,
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors)
downsample_factors=downsample_factors,
groups_denominator=groups_denominator,
max_groups=max_groups)
for _ in range(num_scales)
])
self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False)
self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size,
stride=pooling_stride,
padding=pooling_padding,
count_include_pad=False)
def forward(self, x):
scores = list()

View File

@ -2,17 +2,18 @@ from torch import nn
import torch.nn.functional as F
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
class PeriodDiscriminator(nn.Module):
class PeriodDiscriminator(nn.Module):
def __init__(self, period):
super(PeriodDiscriminator, self).__init__()
layer = []
self.period = period
inp = 1
for l in range(4):
out = int(2 ** (5 + l + 1))
out = int(2**(5 + l + 1))
layer += [
nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
nn.utils.weight_norm(
nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
nn.LeakyReLU(0.2)
]
inp = out
@ -20,8 +21,7 @@ class PeriodDiscriminator(nn.Module):
self.output = nn.Sequential(
nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
)
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))))
def forward(self, x):
batch_size = x.shape[0]
@ -33,39 +33,41 @@ class PeriodDiscriminator(nn.Module):
return self.output(out1)
class MultiPeriodDiscriminator(nn.Module):
class HifiDiscriminator(nn.Module):
def __init__(self,
periods=[2, 3, 5, 7, 11],
in_channels=1,
out_channels=1,
num_scales=3,
kernel_sizes=(5, 3),
base_channels=16,
base_channels=64,
max_channels=1024,
downsample_factors=(4, 4, 4),
downsample_factors=(2, 2, 4, 4),
pooling_kernel_size=4,
pooling_stride=2,
pooling_padding=1):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([ PeriodDiscriminator(periods[0]),
PeriodDiscriminator(periods[1]),
PeriodDiscriminator(periods[2]),
PeriodDiscriminator(periods[3]),
PeriodDiscriminator(periods[4])
])
super().__init__()
self.discriminators = nn.ModuleList([
PeriodDiscriminator(periods[0]),
PeriodDiscriminator(periods[1]),
PeriodDiscriminator(periods[2]),
PeriodDiscriminator(periods[3]),
PeriodDiscriminator(periods[4])
])
self.msd = MelganMultiscaleDiscriminator(
in_channels=in_channels,
out_channels=out_channels,
num_scales=num_scales,
kernel_sizes=kernel_sizes,
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding
)
in_channels=in_channels,
out_channels=out_channels,
num_scales=num_scales,
kernel_sizes=kernel_sizes,
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors,
pooling_kernel_size=pooling_kernel_size,
pooling_stride=pooling_stride,
pooling_padding=pooling_padding,
groups_denominator=32,
max_groups=16)
def forward(self, x):
scores, feats = self.msd(x)

View File

@ -95,11 +95,8 @@ def setup_generator(c):
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=1,
base_channels=c.generator_model_params['upsample_initial_channel'],
upsample_kernel=c.generator_model_params['upsample_kernel_sizes'],
resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'],
resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes'])
elif c.generator_model.lower() in 'melgan_generator':
**c.generator_model_params)
if c.generator_model.lower() in 'melgan_generator':
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=1,
@ -170,16 +167,8 @@ def setup_discriminator(c):
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.discriminator_model.lower())
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in 'multi_period_discriminator':
model = MyModel(
periods=c.discriminator_model_params['peroids'],
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params['base_channels'],
max_channels=c.discriminator_model_params['max_channels'],
downsample_factors=c.
discriminator_model_params['downsample_factors'])
if c.discriminator_model in 'hifigan_discriminator':
model = MyModel()
if c.discriminator_model in 'random_window_discriminator':
model = MyModel(
cond_channels=c.audio['num_mels'],