mirror of https://github.com/coqui-ai/TTS.git
hifigan implementation update
This commit is contained in:
parent
a14d7bc5db
commit
8c9e1c9e58
|
@ -2,26 +2,32 @@ from torch import nn
|
|||
|
||||
|
||||
class ResStack(nn.Module):
|
||||
def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]):
|
||||
def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]):
|
||||
super(ResStack, self).__init__()
|
||||
resstack = []
|
||||
for dilation in dilations:
|
||||
resstack += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(dilation),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(padding),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
||||
]
|
||||
resstack += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(dilation),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(channel,
|
||||
channel,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation)),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(padding),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel, channel,
|
||||
kernel_size=1)),
|
||||
]
|
||||
self.resstack = nn.Sequential(*resstack)
|
||||
|
||||
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
||||
self.shortcut = nn.utils.weight_norm(
|
||||
nn.Conv1d(channel, channel, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.shortcut(x)
|
||||
x2 = self.resstack(x)
|
||||
return x1 + x2
|
||||
x1 = self.shortcut(x)
|
||||
x2 = self.resstack(x)
|
||||
return x1 + x2
|
||||
|
||||
def remove_weight_norm(self):
|
||||
nn.utils.remove_weight_norm(self.shortcut)
|
||||
|
@ -32,18 +38,19 @@ class ResStack(nn.Module):
|
|||
nn.utils.remove_weight_norm(self.resstack[14])
|
||||
nn.utils.remove_weight_norm(self.resstack[17])
|
||||
|
||||
|
||||
class MRF(nn.Module):
|
||||
def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]):
|
||||
super(MRF, self).__init__()
|
||||
self.resblock1 = ResStack(kernels[0], channel, 0)
|
||||
self.resblock2 = ResStack(kernels[1], channel, 6)
|
||||
self.resblock3 = ResStack(kernels[2], channel, 12)
|
||||
def __init__(self, kernels, channel, dilations=[1, 3, 5]):
|
||||
super().__init__()
|
||||
self.resblock1 = ResStack(kernels[0], channel, 0, dilations)
|
||||
self.resblock2 = ResStack(kernels[1], channel, 6, dilations)
|
||||
self.resblock3 = ResStack(kernels[2], channel, 12, dilations)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.resblock1(x)
|
||||
x2 = self.resblock2(x)
|
||||
x3 = self.resblock3(x)
|
||||
return x1 + x2 + x3
|
||||
x1 = self.resblock1(x)
|
||||
x2 = self.resblock2(x)
|
||||
x3 = self.resblock3(x)
|
||||
return x1 + x2 + x3
|
||||
|
||||
def remove_weight_norm(self):
|
||||
self.resblock1.remove_weight_norm()
|
||||
|
|
|
@ -1,71 +1,174 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from TTS.vocoder.layers.hifigan import MRF
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class HifiganGenerator(nn.Module):
|
||||
def get_padding(k, d):
|
||||
return int((k * d - d) / 2)
|
||||
|
||||
def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
|
||||
resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
|
||||
super(HifiganGenerator, self).__init__()
|
||||
|
||||
self.inference_padding = 2
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super().__init__()
|
||||
self.convs1 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2])))
|
||||
])
|
||||
|
||||
self.input = nn.Sequential(
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
|
||||
)
|
||||
|
||||
generator = []
|
||||
|
||||
for k in upsample_kernel:
|
||||
inp = base_channels
|
||||
out = int(inp / 2)
|
||||
generator += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
|
||||
MRF(resblock_kernel_sizes, out, resblock_dilation_sizes)
|
||||
]
|
||||
base_channels = out
|
||||
self.generator = nn.Sequential(*generator)
|
||||
|
||||
self.output = nn.Sequential(
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)),
|
||||
nn.Tanh()
|
||||
|
||||
)
|
||||
self.convs2 = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)))
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.input(x)
|
||||
x2 = self.generator(x1)
|
||||
out = self.output(x2)
|
||||
return out
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs1:
|
||||
remove_weight_norm(l)
|
||||
for l in self.convs2:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class ResBlock2(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||
super().__init__()
|
||||
self.convs = nn.ModuleList([
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]))),
|
||||
weight_norm(
|
||||
Conv1d(channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1])))
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for c in self.convs:
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
|
||||
class HifiganGenerator(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, resblock_type, resblock_dilation_sizes,
|
||||
resblock_kernel_sizes, upsample_kernel_sizes,
|
||||
upsample_initial_channel, upsample_factors):
|
||||
super().__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_factors)
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||
resblock = ResBlock1 if resblock_type == '1' else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_factors,
|
||||
upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(upsample_initial_channel // (2**i),
|
||||
upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2)))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
def inference(self, c):
|
||||
c = c.to(self.layers[1].weight.device)
|
||||
c = c.to(self.conv_pre.weight.device)
|
||||
c = torch.nn.functional.pad(
|
||||
c,
|
||||
(self.inference_padding, self.inference_padding),
|
||||
'replicate')
|
||||
c, (self.inference_padding, self.inference_padding), 'replicate')
|
||||
return self.forward(c)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
nn.utils.remove_weight_norm(self.input[1])
|
||||
nn.utils.remove_weight_norm(self.output[2])
|
||||
|
||||
for idx, layer in enumerate(self.generator):
|
||||
if len(layer.state_dict()) != 0:
|
||||
try:
|
||||
nn.utils.remove_weight_norm(layer)
|
||||
except:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
self.remove_weight_norm()
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
|
|
|
@ -10,7 +10,9 @@ class MelganDiscriminator(nn.Module):
|
|||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4, 4)):
|
||||
downsample_factors=(4, 4, 4, 4),
|
||||
groups_denominator=4,
|
||||
max_groups=256):
|
||||
super(MelganDiscriminator, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
|
@ -35,7 +37,7 @@ class MelganDiscriminator(nn.Module):
|
|||
max_channels)
|
||||
layer_kernel_size = downsample_factor * 10 + 1
|
||||
layer_padding = (layer_kernel_size - 1) // 2
|
||||
layer_groups = layer_in_channels // 4
|
||||
layer_groups = layer_in_channels // groups_denominator
|
||||
self.layers += [
|
||||
nn.Sequential(
|
||||
weight_norm(
|
||||
|
|
|
@ -14,7 +14,9 @@ class MelganMultiscaleDiscriminator(nn.Module):
|
|||
downsample_factors=(4, 4, 4),
|
||||
pooling_kernel_size=4,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1):
|
||||
pooling_padding=2,
|
||||
groups_denominator=4,
|
||||
max_groups=256):
|
||||
super(MelganMultiscaleDiscriminator, self).__init__()
|
||||
|
||||
self.discriminators = nn.ModuleList([
|
||||
|
@ -23,12 +25,16 @@ class MelganMultiscaleDiscriminator(nn.Module):
|
|||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors)
|
||||
downsample_factors=downsample_factors,
|
||||
groups_denominator=groups_denominator,
|
||||
max_groups=max_groups)
|
||||
for _ in range(num_scales)
|
||||
])
|
||||
|
||||
self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False)
|
||||
|
||||
self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size,
|
||||
stride=pooling_stride,
|
||||
padding=pooling_padding,
|
||||
count_include_pad=False)
|
||||
|
||||
def forward(self, x):
|
||||
scores = list()
|
||||
|
|
|
@ -2,17 +2,18 @@ from torch import nn
|
|||
import torch.nn.functional as F
|
||||
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
|
||||
|
||||
class PeriodDiscriminator(nn.Module):
|
||||
|
||||
class PeriodDiscriminator(nn.Module):
|
||||
def __init__(self, period):
|
||||
super(PeriodDiscriminator, self).__init__()
|
||||
layer = []
|
||||
self.period = period
|
||||
inp = 1
|
||||
for l in range(4):
|
||||
out = int(2 ** (5 + l + 1))
|
||||
out = int(2**(5 + l + 1))
|
||||
layer += [
|
||||
nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
inp = out
|
||||
|
@ -20,8 +21,7 @@ class PeriodDiscriminator(nn.Module):
|
|||
self.output = nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
|
||||
)
|
||||
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))))
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
|
@ -33,39 +33,41 @@ class PeriodDiscriminator(nn.Module):
|
|||
return self.output(out1)
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(nn.Module):
|
||||
class HifiDiscriminator(nn.Module):
|
||||
def __init__(self,
|
||||
periods=[2, 3, 5, 7, 11],
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
num_scales=3,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
base_channels=64,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4),
|
||||
downsample_factors=(2, 2, 4, 4),
|
||||
pooling_kernel_size=4,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([ PeriodDiscriminator(periods[0]),
|
||||
PeriodDiscriminator(periods[1]),
|
||||
PeriodDiscriminator(periods[2]),
|
||||
PeriodDiscriminator(periods[3]),
|
||||
PeriodDiscriminator(periods[4])
|
||||
])
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
PeriodDiscriminator(periods[0]),
|
||||
PeriodDiscriminator(periods[1]),
|
||||
PeriodDiscriminator(periods[2]),
|
||||
PeriodDiscriminator(periods[3]),
|
||||
PeriodDiscriminator(periods[4])
|
||||
])
|
||||
|
||||
self.msd = MelganMultiscaleDiscriminator(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_scales=num_scales,
|
||||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding
|
||||
)
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_scales=num_scales,
|
||||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding,
|
||||
groups_denominator=32,
|
||||
max_groups=16)
|
||||
|
||||
def forward(self, x):
|
||||
scores, feats = self.msd(x)
|
||||
|
|
|
@ -95,11 +95,8 @@ def setup_generator(c):
|
|||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
out_channels=1,
|
||||
base_channels=c.generator_model_params['upsample_initial_channel'],
|
||||
upsample_kernel=c.generator_model_params['upsample_kernel_sizes'],
|
||||
resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'],
|
||||
resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes'])
|
||||
elif c.generator_model.lower() in 'melgan_generator':
|
||||
**c.generator_model_params)
|
||||
if c.generator_model.lower() in 'melgan_generator':
|
||||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
out_channels=1,
|
||||
|
@ -170,16 +167,8 @@ def setup_discriminator(c):
|
|||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||
c.discriminator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||
if c.discriminator_model in 'multi_period_discriminator':
|
||||
model = MyModel(
|
||||
periods=c.discriminator_model_params['peroids'],
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=c.discriminator_model_params['base_channels'],
|
||||
max_channels=c.discriminator_model_params['max_channels'],
|
||||
downsample_factors=c.
|
||||
discriminator_model_params['downsample_factors'])
|
||||
if c.discriminator_model in 'hifigan_discriminator':
|
||||
model = MyModel()
|
||||
if c.discriminator_model in 'random_window_discriminator':
|
||||
model = MyModel(
|
||||
cond_channels=c.audio['num_mels'],
|
||||
|
|
Loading…
Reference in New Issue