From 4493feb95ca8894479290bf5d704ddf613131f72 Mon Sep 17 00:00:00 2001 From: rishikksh20 Date: Sun, 14 Feb 2021 01:09:55 +0530 Subject: [PATCH] Add HiFi-GAN v1 generator and discriminator classes --- TTS/vocoder/layers/hifigan.py | 40 +++++++++++++++ TTS/vocoder/models/hifigan_generator.py | 39 ++++++++++++++ .../models/hifigan_mpd_discriminator.py | 51 +++++++++++++++++++ 3 files changed, 130 insertions(+) create mode 100644 TTS/vocoder/layers/hifigan.py create mode 100644 TTS/vocoder/models/hifigan_generator.py create mode 100644 TTS/vocoder/models/hifigan_mpd_discriminator.py diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py new file mode 100644 index 00000000..450253ee --- /dev/null +++ b/TTS/vocoder/layers/hifigan.py @@ -0,0 +1,40 @@ +from torch import nn + + +class ResStack(nn.Module): + def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]): + super(ResStack, self).__init__() + resstack = [] + for dilation in dilations: + resstack += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(padding), + nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), + ] + self.resstack = nn.Sequential(*resstack) + + self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) + + def forward(self, x): + x1 = self.shortcut(x) + x2 = self.resstack(x) + return x1 + x2 + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.shortcut) + +class MRF(nn.Module): + def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]): + super(MRF, self).__init__() + self.resblock1 = ResStack(kernels[0], channel, 0) + self.resblock2 = ResStack(kernels[1], channel, 6) + self.resblock3 = ResStack(kernels[2], channel, 12) + + def forward(self, x): + x1 = self.resblock1(x) + x2 = self.resblock2(x) + x3 = self.resblock3(x) + return x1 + x2 + x3 \ No newline at end of file diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py new file mode 100644 index 00000000..5ca3cf75 --- /dev/null +++ b/TTS/vocoder/models/hifigan_generator.py @@ -0,0 +1,39 @@ +from torch import nn +from TTS.vocoder.layers.hifigan import MRF + + +class Generator(nn.Module): + + def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]): + super(Generator, self).__init__() + self.input = nn.Sequential( + nn.ReflectionPad1d(3), + nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7)) + ) + + generator = [] + + for k in ku: + inp = hu + out = int(inp / 2) + generator += [ + nn.LeakyReLU(0.2), + nn.ConvTranspose1d(inp, out, k, k // 2), + MRF(kr, out, Dr) + ] + hu = out + self.generator = nn.Sequential(*generator) + + self.output = nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(3), + nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)), + nn.Tanh() + + ) + + def forward(self, x): + x1 = self.input(x) + x2 = self.generator(x1) + out = self.output(x2) + return out \ No newline at end of file diff --git a/TTS/vocoder/models/hifigan_mpd_discriminator.py b/TTS/vocoder/models/hifigan_mpd_discriminator.py new file mode 100644 index 00000000..2a095c2e --- /dev/null +++ b/TTS/vocoder/models/hifigan_mpd_discriminator.py @@ -0,0 +1,51 @@ +from torch import nn +import torch.nn.functional as F + + +class PeriodDiscriminator(nn.Module): + + def __init__(self, period): + super(PeriodDiscriminator, self).__init__() + layer = [] + self.period = period + inp = 1 + for l in range(4): + out = int(2 ** (5 + l + 1)) + layer += [ + nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))), + nn.LeakyReLU(0.2) + ] + inp = out + self.layer = nn.Sequential(*layer) + self.output = nn.Sequential( + nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))), + nn.LeakyReLU(0.2), + nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))) + ) + + def forward(self, x): + batch_size = x.shape[0] + pad = self.period - (x.shape[-1] % self.period) + x = F.pad(x, (0, pad), "reflect") + y = x.view(batch_size, -1, self.period).contiguous() + y = y.unsqueeze(1) + out1 = self.layer(y) + return self.output(out1) + + +class MPD(nn.Module): + def __init__(self, periods=[2, 3, 5, 7, 11], segment_length=16000): + super(MPD, self).__init__() + self.mpd1 = PeriodDiscriminator(periods[0]) + self.mpd2 = PeriodDiscriminator(periods[1]) + self.mpd3 = PeriodDiscriminator(periods[2]) + self.mpd4 = PeriodDiscriminator(periods[3]) + self.mpd5 = PeriodDiscriminator(periods[4]) + + def forward(self, x): + out1 = self.mpd1(x) + out2 = self.mpd2(x) + out3 = self.mpd3(x) + out4 = self.mpd4(x) + out5 = self.mpd5(x) + return out1, out2, out3, out4, out5