diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py index 28fbb7b0..c632d079 100644 --- a/TTS/vocoder/layers/hifigan.py +++ b/TTS/vocoder/layers/hifigan.py @@ -2,26 +2,32 @@ from torch import nn class ResStack(nn.Module): - def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]): + def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]): super(ResStack, self).__init__() resstack = [] for dilation in dilations: - resstack += [ - nn.LeakyReLU(0.2), - nn.ReflectionPad1d(dilation), - nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)), - nn.LeakyReLU(0.2), - nn.ReflectionPad1d(padding), - nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), - ] + resstack += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + nn.utils.weight_norm( + nn.Conv1d(channel, + channel, + kernel_size=kernel, + dilation=dilation)), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(padding), + nn.utils.weight_norm(nn.Conv1d(channel, channel, + kernel_size=1)), + ] self.resstack = nn.Sequential(*resstack) - self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) + self.shortcut = nn.utils.weight_norm( + nn.Conv1d(channel, channel, kernel_size=1)) def forward(self, x): - x1 = self.shortcut(x) - x2 = self.resstack(x) - return x1 + x2 + x1 = self.shortcut(x) + x2 = self.resstack(x) + return x1 + x2 def remove_weight_norm(self): nn.utils.remove_weight_norm(self.shortcut) @@ -32,18 +38,19 @@ class ResStack(nn.Module): nn.utils.remove_weight_norm(self.resstack[14]) nn.utils.remove_weight_norm(self.resstack[17]) + class MRF(nn.Module): - def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]): - super(MRF, self).__init__() - self.resblock1 = ResStack(kernels[0], channel, 0) - self.resblock2 = ResStack(kernels[1], channel, 6) - self.resblock3 = ResStack(kernels[2], channel, 12) + def __init__(self, kernels, channel, dilations=[1, 3, 5]): + super().__init__() + self.resblock1 = ResStack(kernels[0], channel, 0, dilations) + self.resblock2 = ResStack(kernels[1], channel, 6, dilations) + self.resblock3 = ResStack(kernels[2], channel, 12, dilations) def forward(self, x): - x1 = self.resblock1(x) - x2 = self.resblock2(x) - x3 = self.resblock3(x) - return x1 + x2 + x3 + x1 = self.resblock1(x) + x2 = self.resblock2(x) + x3 = self.resblock3(x) + return x1 + x2 + x3 def remove_weight_norm(self): self.resblock1.remove_weight_norm() diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index c81f3653..2b1f43f7 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -1,71 +1,174 @@ import torch -from torch import nn -from TTS.vocoder.layers.hifigan import MRF +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +LRELU_SLOPE = 0.1 -class HifiganGenerator(nn.Module): +def get_padding(k, d): + return int((k * d - d) / 2) - def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4], - resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]): - super(HifiganGenerator, self).__init__() - self.inference_padding = 2 +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList([ + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) - self.input = nn.Sequential( - nn.ReflectionPad1d(3), - nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7)) - ) - - generator = [] - - for k in upsample_kernel: - inp = base_channels - out = int(inp / 2) - generator += [ - nn.LeakyReLU(0.2), - nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)), - MRF(resblock_kernel_sizes, out, resblock_dilation_sizes) - ] - base_channels = out - self.generator = nn.Sequential(*generator) - - self.output = nn.Sequential( - nn.LeakyReLU(0.2), - nn.ReflectionPad1d(3), - nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)), - nn.Tanh() - - ) + self.convs2 = nn.ModuleList([ + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1))) + ]) def forward(self, x): - x1 = self.input(x) - x2 = self.generator(x1) - out = self.output(x2) - return out + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList([ + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm( + Conv1d(channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HifiganGenerator(torch.nn.Module): + def __init__(self, in_channels, out_channels, resblock_type, resblock_dilation_sizes, + resblock_kernel_sizes, upsample_kernel_sizes, + upsample_initial_channel, upsample_factors): + super().__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.conv_pre = weight_norm( + Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, + upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d(upsample_initial_channel // (2**i), + upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3)) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x def inference(self, c): - c = c.to(self.layers[1].weight.device) + c = c.to(self.conv_pre.weight.device) c = torch.nn.functional.pad( - c, - (self.inference_padding, self.inference_padding), - 'replicate') + c, (self.inference_padding, self.inference_padding), 'replicate') return self.forward(c) def remove_weight_norm(self): - nn.utils.remove_weight_norm(self.input[1]) - nn.utils.remove_weight_norm(self.output[2]) - - for idx, layer in enumerate(self.generator): - if len(layer.state_dict()) != 0: - try: - nn.utils.remove_weight_norm(layer) - except: - layer.remove_weight_norm() - - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) - if eval: - self.eval() - assert not self.training - self.remove_weight_norm() \ No newline at end of file + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py index 3847babb..8443a3b9 100644 --- a/TTS/vocoder/models/melgan_discriminator.py +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -10,7 +10,9 @@ class MelganDiscriminator(nn.Module): kernel_sizes=(5, 3), base_channels=16, max_channels=1024, - downsample_factors=(4, 4, 4, 4)): + downsample_factors=(4, 4, 4, 4), + groups_denominator=4, + max_groups=256): super(MelganDiscriminator, self).__init__() self.layers = nn.ModuleList() @@ -35,7 +37,7 @@ class MelganDiscriminator(nn.Module): max_channels) layer_kernel_size = downsample_factor * 10 + 1 layer_padding = (layer_kernel_size - 1) // 2 - layer_groups = layer_in_channels // 4 + layer_groups = layer_in_channels // groups_denominator self.layers += [ nn.Sequential( weight_norm( diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py index 0f9cca96..3ab6e13c 100644 --- a/TTS/vocoder/models/melgan_multiscale_discriminator.py +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -14,7 +14,9 @@ class MelganMultiscaleDiscriminator(nn.Module): downsample_factors=(4, 4, 4), pooling_kernel_size=4, pooling_stride=2, - pooling_padding=1): + pooling_padding=2, + groups_denominator=4, + max_groups=256): super(MelganMultiscaleDiscriminator, self).__init__() self.discriminators = nn.ModuleList([ @@ -23,12 +25,16 @@ class MelganMultiscaleDiscriminator(nn.Module): kernel_sizes=kernel_sizes, base_channels=base_channels, max_channels=max_channels, - downsample_factors=downsample_factors) + downsample_factors=downsample_factors, + groups_denominator=groups_denominator, + max_groups=max_groups) for _ in range(num_scales) ]) - self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False) - + self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, + stride=pooling_stride, + padding=pooling_padding, + count_include_pad=False) def forward(self, x): scores = list() diff --git a/TTS/vocoder/models/multi_period_discriminator.py b/TTS/vocoder/models/multi_period_discriminator.py index 69d12be4..8f821a87 100644 --- a/TTS/vocoder/models/multi_period_discriminator.py +++ b/TTS/vocoder/models/multi_period_discriminator.py @@ -2,17 +2,18 @@ from torch import nn import torch.nn.functional as F from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator -class PeriodDiscriminator(nn.Module): +class PeriodDiscriminator(nn.Module): def __init__(self, period): super(PeriodDiscriminator, self).__init__() layer = [] self.period = period inp = 1 for l in range(4): - out = int(2 ** (5 + l + 1)) + out = int(2**(5 + l + 1)) layer += [ - nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))), + nn.utils.weight_norm( + nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))), nn.LeakyReLU(0.2) ] inp = out @@ -20,8 +21,7 @@ class PeriodDiscriminator(nn.Module): self.output = nn.Sequential( nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))), nn.LeakyReLU(0.2), - nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))) - ) + nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))) def forward(self, x): batch_size = x.shape[0] @@ -33,39 +33,41 @@ class PeriodDiscriminator(nn.Module): return self.output(out1) -class MultiPeriodDiscriminator(nn.Module): +class HifiDiscriminator(nn.Module): def __init__(self, periods=[2, 3, 5, 7, 11], in_channels=1, out_channels=1, num_scales=3, kernel_sizes=(5, 3), - base_channels=16, + base_channels=64, max_channels=1024, - downsample_factors=(4, 4, 4), + downsample_factors=(2, 2, 4, 4), pooling_kernel_size=4, pooling_stride=2, pooling_padding=1): - super(MultiPeriodDiscriminator, self).__init__() - self.discriminators = nn.ModuleList([ PeriodDiscriminator(periods[0]), - PeriodDiscriminator(periods[1]), - PeriodDiscriminator(periods[2]), - PeriodDiscriminator(periods[3]), - PeriodDiscriminator(periods[4]) - ]) + super().__init__() + self.discriminators = nn.ModuleList([ + PeriodDiscriminator(periods[0]), + PeriodDiscriminator(periods[1]), + PeriodDiscriminator(periods[2]), + PeriodDiscriminator(periods[3]), + PeriodDiscriminator(periods[4]) + ]) self.msd = MelganMultiscaleDiscriminator( - in_channels=in_channels, - out_channels=out_channels, - num_scales=num_scales, - kernel_sizes=kernel_sizes, - base_channels=base_channels, - max_channels=max_channels, - downsample_factors=downsample_factors, - pooling_kernel_size=pooling_kernel_size, - pooling_stride=pooling_stride, - pooling_padding=pooling_padding - ) + in_channels=in_channels, + out_channels=out_channels, + num_scales=num_scales, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + pooling_kernel_size=pooling_kernel_size, + pooling_stride=pooling_stride, + pooling_padding=pooling_padding, + groups_denominator=32, + max_groups=16) def forward(self, x): scores, feats = self.msd(x) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 2fcc0171..7f4c187f 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -95,11 +95,8 @@ def setup_generator(c): model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, - base_channels=c.generator_model_params['upsample_initial_channel'], - upsample_kernel=c.generator_model_params['upsample_kernel_sizes'], - resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'], - resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes']) - elif c.generator_model.lower() in 'melgan_generator': + **c.generator_model_params) + if c.generator_model.lower() in 'melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, @@ -170,16 +167,8 @@ def setup_discriminator(c): MyModel = importlib.import_module('TTS.vocoder.models.' + c.discriminator_model.lower()) MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) - if c.discriminator_model in 'multi_period_discriminator': - model = MyModel( - periods=c.discriminator_model_params['peroids'], - in_channels=1, - out_channels=1, - kernel_sizes=(5, 3), - base_channels=c.discriminator_model_params['base_channels'], - max_channels=c.discriminator_model_params['max_channels'], - downsample_factors=c. - discriminator_model_params['downsample_factors']) + if c.discriminator_model in 'hifigan_discriminator': + model = MyModel() if c.discriminator_model in 'random_window_discriminator': model = MyModel( cond_channels=c.audio['num_mels'],