mirror of https://github.com/coqui-ai/TTS.git
51 lines
1.4 KiB
Python
51 lines
1.4 KiB
Python
from torch import nn
|
|
|
|
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
|
|
|
|
|
class MelganMultiscaleDiscriminator(nn.Module):
    """Stack of MelGAN discriminators applied to progressively downsampled input.

    ``num_scales`` independent ``MelganDiscriminator`` instances are built with
    identical hyper-parameters. They do not share weights: a new instance is
    created per scale. In ``forward`` the first discriminator sees the raw
    input. Each subsequent one sees the previous scale's input passed once
    more through a shared ``nn.AvgPool1d``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        """Build the per-scale discriminators and the inter-scale pooling layer.

        Args:
            in_channels, out_channels, kernel_sizes, base_channels, max_channels,
            downsample_factors, groups_denominator: forwarded unchanged to every
                ``MelganDiscriminator``.
            num_scales: number of discriminators (and thus of outputs from
                ``forward``).
            pooling_kernel_size, pooling_stride, pooling_padding: configuration
                of the ``nn.AvgPool1d`` applied between scales. PyTorch requires
                ``pooling_padding <= pooling_kernel_size / 2``.
        """
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them stable.
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # count_include_pad=False: padded positions are excluded from the
        # average, so edge samples are not biased toward zero.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False
        )

    def forward(self, x):
        """Run every discriminator on its own resolution of ``x``.

        Args:
            x: input tensor. It is pooled along its last dimension by
                ``nn.AvgPool1d``. Presumably shaped (batch, in_channels, time);
                TODO confirm against callers.

        Returns:
            ``(scores, feats)``: two lists with one entry per discriminator, in
            scale order (full resolution first). Each entry is whatever the
            corresponding ``MelganDiscriminator`` returns as its first and second
            outputs.
        """
        scores = []
        feats = []
        last_idx = len(self.discriminators) - 1
        for idx, disc in enumerate(self.discriminators):
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            # Downsample only when another scale will consume the result.
            # Pooling after the final discriminator would be discarded work.
            if idx < last_idx:
                x = self.pooling(x)
        return scores, feats
|