mirror of https://github.com/coqui-ai/TTS.git
Make hifigan discriminator configurable
This commit is contained in:
parent
c437db15fd
commit
8e915b70e0
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn.modules.conv import Conv1d
|
||||
|
||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
|
||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
|
@ -12,19 +12,19 @@ class DiscriminatorS(torch.nn.Module):
|
|||
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||
"""
|
||||
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
def __init__(self, use_spectral_norm=False, upsampling_rates=[4, 4, 4, 4]):
|
||||
super().__init__()
|
||||
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.convs = nn.ModuleList([norm_f(Conv1d(1, 16, 15, 1, padding=7))])
|
||||
groups = 4
|
||||
in_channels = 16
|
||||
out_channels = 64
|
||||
for rate in upsampling_rates:
|
||||
self.convs.append(norm_f(Conv1d(in_channels, out_channels, 41, rate, groups=groups, padding=20)))
|
||||
groups = min(groups * rate, 256)
|
||||
in_channels = min(in_channels * rate, 1024)
|
||||
out_channels = min(out_channels * rate, 1024)
|
||||
self.convs += [norm_f(Conv1d(1024, 1024, 5, 1, padding=2))]
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -58,10 +58,10 @@ class VitsDiscriminator(nn.Module):
|
|||
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
|
||||
"""
|
||||
|
||||
def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
|
||||
def __init__(self, use_spectral_norm=False, periods=[2, 3, 5, 7, 11], upsampling_rates=[4,4,4,4]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList()
|
||||
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
|
||||
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm, upsampling_rates=upsampling_rates))
|
||||
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
|
||||
|
||||
def forward(self, x, x_hat=None):
|
||||
|
|
Loading…
Reference in New Issue