Add HiFi-GAN v1 generator and discriminator classes

This commit is contained in:
rishikksh20 2021-02-14 01:09:55 +05:30 committed by Eren Gölge
parent a669a492c6
commit 4493feb95c
3 changed files with 130 additions and 0 deletions

View File

@ -0,0 +1,40 @@
from torch import nn
class ResStack(nn.Module):
def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]):
super(ResStack, self).__init__()
resstack = []
for dilation in dilations:
resstack += [
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(dilation),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(padding),
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
]
self.resstack = nn.Sequential(*resstack)
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
def forward(self, x):
x1 = self.shortcut(x)
x2 = self.resstack(x)
return x1 + x2
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.shortcut)
class MRF(nn.Module):
def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]):
super(MRF, self).__init__()
self.resblock1 = ResStack(kernels[0], channel, 0)
self.resblock2 = ResStack(kernels[1], channel, 6)
self.resblock3 = ResStack(kernels[2], channel, 12)
def forward(self, x):
x1 = self.resblock1(x)
x2 = self.resblock2(x)
x3 = self.resblock3(x)
return x1 + x2 + x3

View File

@ -0,0 +1,39 @@
from torch import nn
from TTS.vocoder.layers.hifigan import MRF
class Generator(nn.Module):
def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]):
super(Generator, self).__init__()
self.input = nn.Sequential(
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7))
)
generator = []
for k in ku:
inp = hu
out = int(inp / 2)
generator += [
nn.LeakyReLU(0.2),
nn.ConvTranspose1d(inp, out, k, k // 2),
MRF(kr, out, Dr)
]
hu = out
self.generator = nn.Sequential(*generator)
self.output = nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)),
nn.Tanh()
)
def forward(self, x):
x1 = self.input(x)
x2 = self.generator(x1)
out = self.output(x2)
return out

View File

@ -0,0 +1,51 @@
from torch import nn
import torch.nn.functional as F
class PeriodDiscriminator(nn.Module):
def __init__(self, period):
super(PeriodDiscriminator, self).__init__()
layer = []
self.period = period
inp = 1
for l in range(4):
out = int(2 ** (5 + l + 1))
layer += [
nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
nn.LeakyReLU(0.2)
]
inp = out
self.layer = nn.Sequential(*layer)
self.output = nn.Sequential(
nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
)
def forward(self, x):
batch_size = x.shape[0]
pad = self.period - (x.shape[-1] % self.period)
x = F.pad(x, (0, pad), "reflect")
y = x.view(batch_size, -1, self.period).contiguous()
y = y.unsqueeze(1)
out1 = self.layer(y)
return self.output(out1)
class MPD(nn.Module):
def __init__(self, periods=[2, 3, 5, 7, 11], segment_length=16000):
super(MPD, self).__init__()
self.mpd1 = PeriodDiscriminator(periods[0])
self.mpd2 = PeriodDiscriminator(periods[1])
self.mpd3 = PeriodDiscriminator(periods[2])
self.mpd4 = PeriodDiscriminator(periods[3])
self.mpd5 = PeriodDiscriminator(periods[4])
def forward(self, x):
out1 = self.mpd1(x)
out2 = self.mpd2(x)
out3 = self.mpd3(x)
out4 = self.mpd4(x)
out5 = self.mpd5(x)
return out1, out2, out3, out4, out5