mirror of https://github.com/coqui-ai/TTS.git
90 lines
3.3 KiB
Python
import torch
|
|
from torch import nn
|
|
from torch.nn.modules.conv import Conv1d
|
|
|
|
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
|
|
|
|
|
class DiscriminatorS(torch.nn.Module):
    """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.

    Layout: a 1 -> 16 channel input conv, one strided grouped conv per entry of
    ``upsampling_rates``, a non-strided conv, and a projection to a single channel.
    With the default rates the layers are 16->64->256->1024->1024 channels with
    4/16/64/256 groups.

    Args:
        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
        upsampling_rates (Sequence[int]): stride of each strided conv layer. Each rate also
            multiplies the output channel count (capped at 1024) and the group count
            (capped at 256) of the following strided layer. Group counts must divide the
            channel counts, which ``Conv1d`` enforces; rates that are not powers of two
            may violate this. Defaults to ``(4, 4, 4, 4)``.
    """

    def __init__(self, use_spectral_norm=False, upsampling_rates=(4, 4, 4, 4)):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
        self.convs = nn.ModuleList([norm_f(Conv1d(1, 16, 15, 1, padding=7))])
        groups = 4
        in_channels = 16
        out_channels = 64
        for rate in upsampling_rates:
            self.convs.append(norm_f(Conv1d(in_channels, out_channels, 41, rate, groups=groups, padding=20)))
            # The next layer must consume exactly the channels this layer produces;
            # scaling in_channels independently only lined up when every rate was 4.
            in_channels = out_channels
            out_channels = min(out_channels * rate, 1024)
            groups = min(groups * rate, 256)
        # Size the trailing layers from the last layer's output instead of assuming 1024,
        # so non-default (or empty) upsampling_rates still produce a consistent network.
        self.convs.append(norm_f(Conv1d(in_channels, in_channels, 5, 1, padding=2)))
        self.conv_post = norm_f(Conv1d(in_channels, 1, 3, 1, padding=1))

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform; the first conv has a single input channel,
                so presumably shaped ``(batch, 1, time)``.

        Returns:
            Tensor: discriminator scores, flattened from dim 1 onwards.
            List[Tensor]: list of features from the convolutional layers (after the
                leaky ReLU) followed by the raw ``conv_post`` output.
        """
        feat = []
        for conv in self.convs:
            x = torch.nn.functional.leaky_relu(conv(x), 0.1)
            feat.append(x)
        x = self.conv_post(x)
        feat.append(x)
        x = torch.flatten(x, 1, -1)
        return x, feat
|
|
|
|
|
|
class VitsDiscriminator(nn.Module):
    """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.

    ::

        waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
               |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^

    Args:
        use_spectral_norm (bool): if `True` switch to spectral norm instead of weight norm.
        periods (Sequence[int]): one period discriminator is created per entry.
            Defaults to ``(2, 3, 5, 7, 11)``.
        upsampling_rates (Sequence[int]): strides forwarded to the scale discriminator.
            Defaults to ``(4, 4, 4, 4)``.
    """

    def __init__(self, use_spectral_norm=False, periods=(2, 3, 5, 7, 11), upsampling_rates=(4, 4, 4, 4)):
        super().__init__()
        self.nets = nn.ModuleList()
        # The scale discriminator always comes first, followed by the period discriminators
        # in the order given by `periods`.
        self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm, upsampling_rates=upsampling_rates))
        self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])

    def forward(self, x, x_hat=None):
        """
        Args:
            x (Tensor): ground truth waveform.
            x_hat (Tensor, optional): predicted waveform. When omitted only `x` is scored.

        Returns:
            List[Tensor]: discriminator scores for `x`, one entry per sub-discriminator.
            List[List[Tensor]]: features from each layer of each sub-discriminator for `x`.
            List[Tensor] or None: scores for `x_hat`; `None` when `x_hat` is not given.
            List[List[Tensor]] or None: features for `x_hat`; `None` when `x_hat` is not given.
        """
        score_hat = x_hat is not None
        x_scores = []
        x_feats = []
        x_hat_scores = [] if score_hat else None
        x_hat_feats = [] if score_hat else None
        for net in self.nets:
            x_score, x_feat = net(x)
            x_scores.append(x_score)
            x_feats.append(x_feat)
            if score_hat:
                x_hat_score, x_hat_feat = net(x_hat)
                x_hat_scores.append(x_hat_score)
                x_hat_feats.append(x_hat_feat)
        return x_scores, x_feats, x_hat_scores, x_hat_feats
|