mirror of https://github.com/coqui-ai/TTS.git
Add HiFi-GAN v1 generator and discriminator classes
This commit is contained in:
parent
a669a492c6
commit
4493feb95c
|
@ -0,0 +1,40 @@
|
|||
from torch import nn
|
||||
|
||||
|
||||
class ResStack(nn.Module):
|
||||
def __init__(self, kernel, channel, padding, dilations = [1, 3, 5]):
|
||||
super(ResStack, self).__init__()
|
||||
resstack = []
|
||||
for dilation in dilations:
|
||||
resstack += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(dilation),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation)),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(padding),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)),
|
||||
]
|
||||
self.resstack = nn.Sequential(*resstack)
|
||||
|
||||
self.shortcut = nn.utils.weight_norm(nn.Conv1d(channel, channel, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.shortcut(x)
|
||||
x2 = self.resstack(x)
|
||||
return x1 + x2
|
||||
|
||||
def remove_weight_norm(self):
|
||||
nn.utils.remove_weight_norm(self.shortcut)
|
||||
|
||||
class MRF(nn.Module):
|
||||
def __init__(self, kernels, channel, dilations = [[1,1], [3,1], [5,1]]):
|
||||
super(MRF, self).__init__()
|
||||
self.resblock1 = ResStack(kernels[0], channel, 0)
|
||||
self.resblock2 = ResStack(kernels[1], channel, 6)
|
||||
self.resblock3 = ResStack(kernels[2], channel, 12)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.resblock1(x)
|
||||
x2 = self.resblock2(x)
|
||||
x3 = self.resblock3(x)
|
||||
return x1 + x2 + x3
|
|
@ -0,0 +1,39 @@
|
|||
from torch import nn
|
||||
from TTS.vocoder.layers.hifigan import MRF
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
|
||||
def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]):
|
||||
super(Generator, self).__init__()
|
||||
self.input = nn.Sequential(
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7))
|
||||
)
|
||||
|
||||
generator = []
|
||||
|
||||
for k in ku:
|
||||
inp = hu
|
||||
out = int(inp / 2)
|
||||
generator += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ConvTranspose1d(inp, out, k, k // 2),
|
||||
MRF(kr, out, Dr)
|
||||
]
|
||||
hu = out
|
||||
self.generator = nn.Sequential(*generator)
|
||||
|
||||
self.output = nn.Sequential(
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)),
|
||||
nn.Tanh()
|
||||
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.input(x)
|
||||
x2 = self.generator(x1)
|
||||
out = self.output(x2)
|
||||
return out
|
|
@ -0,0 +1,51 @@
|
|||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PeriodDiscriminator(nn.Module):
|
||||
|
||||
def __init__(self, period):
|
||||
super(PeriodDiscriminator, self).__init__()
|
||||
layer = []
|
||||
self.period = period
|
||||
inp = 1
|
||||
for l in range(4):
|
||||
out = int(2 ** (5 + l + 1))
|
||||
layer += [
|
||||
nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
inp = out
|
||||
self.layer = nn.Sequential(*layer)
|
||||
self.output = nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
pad = self.period - (x.shape[-1] % self.period)
|
||||
x = F.pad(x, (0, pad), "reflect")
|
||||
y = x.view(batch_size, -1, self.period).contiguous()
|
||||
y = y.unsqueeze(1)
|
||||
out1 = self.layer(y)
|
||||
return self.output(out1)
|
||||
|
||||
|
||||
class MPD(nn.Module):
|
||||
def __init__(self, periods=[2, 3, 5, 7, 11], segment_length=16000):
|
||||
super(MPD, self).__init__()
|
||||
self.mpd1 = PeriodDiscriminator(periods[0])
|
||||
self.mpd2 = PeriodDiscriminator(periods[1])
|
||||
self.mpd3 = PeriodDiscriminator(periods[2])
|
||||
self.mpd4 = PeriodDiscriminator(periods[3])
|
||||
self.mpd5 = PeriodDiscriminator(periods[4])
|
||||
|
||||
def forward(self, x):
|
||||
out1 = self.mpd1(x)
|
||||
out2 = self.mpd2(x)
|
||||
out3 = self.mpd3(x)
|
||||
out4 = self.mpd4(x)
|
||||
out5 = self.mpd5(x)
|
||||
return out1, out2, out3, out4, out5
|
Loading…
Reference in New Issue