mirror of https://github.com/coqui-ai/TTS.git
add hifigan D
This commit is contained in:
parent
13dca6e6b6
commit
7cecd2fb2e
@@ -0,0 +1,212 @@
# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py
import torch
from torch import nn
from torch.nn import functional as F


LRELU_SLOPE = 0.1


class DiscriminatorP(torch.nn.Module):
    """HiFiGAN Periodic Discriminator

    Takes every Pth value from the input waveform and applies a stack of convolutions.

    Note:
        if `period` is 2
        `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat`

    Args:
        x (Tensor): input waveform.

    Returns:
        [Tensor]: discriminator scores per sample in the batch.
        [List[Tensor]]: list of features from each convolutional layer.

    Shapes:
        x: [B, 1, T]
    """
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        get_padding = lambda k, d: int((k * d - d) / 2)
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [Tensor]: discriminator scores per sample in the batch.
            [List[Tensor]]: list of features from each convolutional layer.

        Shapes:
            x: [B, 1, T]
        """
        feat = []

        # 1d to 2d: pad T to a multiple of `period`, then fold into [B, 1, T // period, period]
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)

        return x, feat
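
# A quick illustration of the pad-and-fold step above (toy values assumed
# purely for clarity, not taken from the file):
#
#     x = torch.randn(1, 1, 10)             # [B, 1, T] with T=10, period=3
#     n_pad = 3 - (10 % 3)                  # 2 samples of reflection padding
#     x = F.pad(x, (0, n_pad), "reflect")   # -> [1, 1, 12]
#     x = x.view(1, 1, 12 // 3, 3)          # -> [1, 1, 4, 3]
#
# Column p of the last axis then holds samples p, p+3, p+6, ..., and the
# (kernel_size, 1) convolutions slide only along the time axis, so each
# period offset is processed by the same filters independently.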


class MultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Period Discriminator (MPD)

    Wrapper for `DiscriminatorP`, applying it at several periods.
    Periods are suggested to be prime numbers to reduce the overlap between the discriminators.
    """
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            [List[Tensor]]: list of scores from each discriminator.
            [List[List[Tensor]]]: list of features from each discriminator's convolutional layers.

        Shapes:
            x: [B, 1, T]
        """
        scores = []
        feats = []
        for d in self.discriminators:
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator.

    Similar to `MelganDiscriminator` but with the specific architecture described in the paper.

    Args:
        use_spectral_norm (bool): if `True`, use spectral norm instead of weight norm.
    """
    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList([
            norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            Tensor: discriminator scores.
            List[Tensor]: list of features from the convolutional layers.
        """
        feat = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat


class MultiScaleDiscriminator(torch.nn.Module):
    """HiFiGAN Multi-Scale Discriminator.

    Similar to `MultiScaleMelganDiscriminator` but specifically tailored for HiFiGAN as in the paper.
    """
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(4, 2, padding=2),
        ])

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of features from each discriminator's convolutional layers.
        """
        scores = []
        feats = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                # each subsequent discriminator sees a 2x average-pooled waveform
                x = self.meanpools[i - 1](x)
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats


class HifiganDiscriminator(nn.Module):
    """HiFiGAN discriminator wrapping the MPD and the MSD."""
    def __init__(self):
        super().__init__()
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiScaleDiscriminator()

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores.
            List[List[Tensor]]: list of features from each discriminator's convolutional layers.
        """
        scores, feats = self.msd(x)
        scores_, feats_ = self.mpd(x)
        scores += scores_
        feats += feats_
        return scores, feats
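
For a quick sanity check of the new wrapper, a minimal usage sketch (the batch size and waveform length are arbitrary illustration values, and the classes are assumed to be in scope as defined in the file above):

    import torch

    # with HifiganDiscriminator and its sub-discriminators defined as above
    disc = HifiganDiscriminator()
    x = torch.randn(4, 1, 22050)  # [B, 1, T]: a batch of 4 one-second 22.05 kHz waveforms
    scores, feats = disc(x)

    # the MSD contributes 3 entries and the MPD 5, so both lists have 8 items;
    # each score tensor is flattened to [B, -1], and feats holds the per-layer
    # feature maps typically used for a feature-matching loss.
    assert len(scores) == len(feats) == 8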

@@ -1,77 +0,0 @@
from torch import nn
import torch.nn.functional as F

from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator


class PeriodDiscriminator(nn.Module):
    def __init__(self, period):
        super(PeriodDiscriminator, self).__init__()
        layer = []
        self.period = period
        inp = 1
        for l in range(4):
            out = int(2**(5 + l + 1))
            layer += [
                nn.utils.weight_norm(
                    nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
                nn.LeakyReLU(0.2)
            ]
            inp = out
        self.layer = nn.Sequential(*layer)
        self.output = nn.Sequential(
            nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
            nn.LeakyReLU(0.2),
            nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))))

    def forward(self, x):
        batch_size = x.shape[0]
        pad = self.period - (x.shape[-1] % self.period)
        x = F.pad(x, (0, pad))
        y = x.view(batch_size, -1, self.period).contiguous()
        y = y.unsqueeze(1)
        out1 = self.layer(y)
        return self.output(out1)


class HifiDiscriminator(nn.Module):
    def __init__(self,
                 periods=[2, 3, 5, 7, 11],
                 in_channels=1,
                 out_channels=1,
                 num_scales=3,
                 kernel_sizes=(5, 3),
                 base_channels=64,
                 max_channels=1024,
                 downsample_factors=(2, 2, 4, 4),
                 pooling_kernel_size=4,
                 pooling_stride=2,
                 pooling_padding=1):
        super().__init__()
        self.discriminators = nn.ModuleList([
            PeriodDiscriminator(periods[0]),
            PeriodDiscriminator(periods[1]),
            PeriodDiscriminator(periods[2]),
            PeriodDiscriminator(periods[3]),
            PeriodDiscriminator(periods[4])
        ])

        self.msd = MelganMultiscaleDiscriminator(
            in_channels=in_channels,
            out_channels=out_channels,
            num_scales=num_scales,
            kernel_sizes=kernel_sizes,
            base_channels=base_channels,
            max_channels=max_channels,
            downsample_factors=downsample_factors,
            pooling_kernel_size=pooling_kernel_size,
            pooling_stride=pooling_stride,
            pooling_padding=pooling_padding,
            groups_denominator=32,
            max_groups=16)

    def forward(self, x):
        scores, feats = self.msd(x)
        for _, disc in enumerate(self.discriminators):
            score = disc(x)
            scores.append(score)
        return scores, feats