From 7b7c5d635fc42be520f86524c6d3c79d7ef48d39 Mon Sep 17 00:00:00 2001
From: rishikksh20
Date: Thu, 18 Feb 2021 01:26:58 +0530
Subject: [PATCH] 1) Combine MSD with the Multi-Period discriminator 2) Add remove_weight_norm method to the Generator

---
 TTS/vocoder/layers/hifigan.py               | 15 +++-
 TTS/vocoder/models/hifigan_generator.py     | 25 +++++-
 .../models/hifigan_mpd_discriminator.py     | 80 +++++++++++--------
 3 files changed, 85 insertions(+), 35 deletions(-)

diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py
index 450253ee..a60f96be 100644
--- a/TTS/vocoder/layers/hifigan.py
+++ b/TTS/vocoder/layers/hifigan.py
@@ -24,6 +24,14 @@ class ResStack(nn.Module):
         return x1 + x2
 
     def remove_weight_norm(self):
+        # nn.utils.remove_weight_norm(self.resstack[2])
+        # nn.utils.remove_weight_norm(self.resstack[4])
+        for layer in self.resstack:
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
         nn.utils.remove_weight_norm(self.shortcut)
 
 class MRF(nn.Module):
@@ -37,4 +45,9 @@
         x1 = self.resblock1(x)
         x2 = self.resblock2(x)
         x3 = self.resblock3(x)
-        return x1 + x2 + x3
\ No newline at end of file
+        return x1 + x2 + x3
+
+    def remove_weight_norm(self):
+        self.resblock1.remove_weight_norm()
+        self.resblock2.remove_weight_norm()
+        self.resblock3.remove_weight_norm()
\ No newline at end of file
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index 5ca3cf75..24aed482 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -18,7 +18,7 @@ class Generator(nn.Module):
             out = int(inp / 2)
             generator += [
                 nn.LeakyReLU(0.2),
-                nn.ConvTranspose1d(inp, out, k, k // 2),
+                nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k // 2)),
                 MRF(kr, out, Dr)
             ]
             hu = out
@@ -36,4 +36,25 @@
         x1 = self.input(x)
         x2 = self.generator(x1)
         out = self.output(x2)
-        return out
\ No newline at end of file
+        return out
+
+    def remove_weight_norm(self):
+        for layer in self.input:
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+
+        for layer in self.output:
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
+        for layer in self.generator:
+            if len(layer.state_dict()) != 0:
+                try:
+                    nn.utils.remove_weight_norm(layer)
+                except ValueError:
+                    layer.remove_weight_norm()
\ No newline at end of file
diff --git a/TTS/vocoder/models/hifigan_mpd_discriminator.py b/TTS/vocoder/models/hifigan_mpd_discriminator.py
index 2a095c2e..84891b4e 100644
--- a/TTS/vocoder/models/hifigan_mpd_discriminator.py
+++ b/TTS/vocoder/models/hifigan_mpd_discriminator.py
@@ -1,27 +1,37 @@
 from torch import nn
 import torch.nn.functional as F
-
+from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
 
 class PeriodDiscriminator(nn.Module):
 
     def __init__(self, period):
         super(PeriodDiscriminator, self).__init__()
-        layer = []
+
         self.period = period
-        inp = 1
-        for l in range(4):
-            out = int(2 ** (5 + l + 1))
-            layer += [
-                nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
-                nn.LeakyReLU(0.2)
-            ]
-            inp = out
-        self.layer = nn.Sequential(*layer)
-        self.output = nn.Sequential(
-            nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
-            nn.LeakyReLU(0.2),
-            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
-        )
+
+        self.discriminator = nn.ModuleList([
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(1, 64, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(3, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.Sequential(
+                nn.utils.weight_norm(nn.Conv2d(512, 1024, kernel_size=(5, 1))),
+                nn.LeakyReLU(0.2, inplace=True),
+            ),
+            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))),
+        ])
 
     def forward(self, x):
         batch_size = x.shape[0]
@@ -29,23 +39,29 @@ class PeriodDiscriminator(nn.Module):
         x = F.pad(x, (0, pad), "reflect")
         y = x.view(batch_size, -1, self.period).contiguous()
         y = y.unsqueeze(1)
-        out1 = self.layer(y)
-        return self.output(out1)
+        features = list()
+        for module in self.discriminator:
+            y = module(y)
+            features.append(y)
+        return features[-1], features[:-1]
 
 
-class MPD(nn.Module):
-    def __init__(self, periods=[2, 3, 5, 7, 11], segment_length=16000):
-        super(MPD, self).__init__()
-        self.mpd1 = PeriodDiscriminator(periods[0])
-        self.mpd2 = PeriodDiscriminator(periods[1])
-        self.mpd3 = PeriodDiscriminator(periods[2])
-        self.mpd4 = PeriodDiscriminator(periods[3])
-        self.mpd5 = PeriodDiscriminator(periods[4])
+class HiFiDiscriminator(nn.Module):
+    def __init__(self, periods=[2, 3, 5, 7, 11]):
+        super(HiFiDiscriminator, self).__init__()
+        self.discriminators = nn.ModuleList([PeriodDiscriminator(periods[0]),
+            PeriodDiscriminator(periods[1]),
+            PeriodDiscriminator(periods[2]),
+            PeriodDiscriminator(periods[3]),
+            PeriodDiscriminator(periods[4]),
+        ])
+
+        self.msd = MelganMultiscaleDiscriminator()
 
     def forward(self, x):
-        out1 = self.mpd1(x)
-        out2 = self.mpd2(x)
-        out3 = self.mpd3(x)
-        out4 = self.mpd4(x)
-        out5 = self.mpd5(x)
-        return out1, out2, out3, out4, out5
+        scores, feats = self.msd(x)
+        for disc in self.discriminators:
+            score, feat = disc(x)
+            scores.append(score)
+            feats.append(feat)
+        return scores, feats
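
Below is a minimal smoke-test sketch for the combined discriminator added by this patch. It is not part of the patch itself: the batch size and waveform length are arbitrary, and it assumes MelganMultiscaleDiscriminator is constructible with its default arguments, as the patch's own __init__ does.

    import torch

    from TTS.vocoder.models.hifigan_mpd_discriminator import HiFiDiscriminator

    # Fake waveform batch: (batch, channels, samples); the length is arbitrary.
    wav = torch.randn(4, 1, 16000)

    # Default periods [2, 3, 5, 7, 11] plus the MelGAN multi-scale discriminator.
    disc = HiFiDiscriminator()
    scores, feats = disc(wav)

    # One score per sub-discriminator: MSD scales first, then the five period
    # discriminators; feats holds the matching lists of intermediate feature
    # maps, typically consumed by a feature-matching loss.
    print(len(scores), len(feats))

On the generator side, the new remove_weight_norm() is the usual one-shot call made after loading a trained checkpoint and before inference or export, mirroring the reference HiFi-GAN implementation.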