mirror of https://github.com/coqui-ai/TTS.git
* Format the model definition
* Update code and integrate training code
This commit is contained in:
parent
39b5845810
commit
c20a6b1185
|
@ -83,7 +83,7 @@
|
|||
"upsample_kernel_sizes": [16,16,4,4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]]
|
||||
"resblock_dilation_sizes": [1,3,5]
|
||||
},
|
||||
|
||||
// DATASET
|
||||
|
|
|
@ -4,30 +4,31 @@ from TTS.vocoder.layers.hifigan import MRF
|
|||
|
||||
class Generator(nn.Module):
|
||||
|
||||
def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]):
|
||||
def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
|
||||
resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
|
||||
super(Generator, self).__init__()
|
||||
self.input = nn.Sequential(
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7))
|
||||
nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
|
||||
)
|
||||
|
||||
generator = []
|
||||
|
||||
for k in ku:
|
||||
inp = hu
|
||||
for k in upsample_kernel:
|
||||
inp = base_channels
|
||||
out = int(inp / 2)
|
||||
generator += [
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
|
||||
MRF(kr, out, Dr)
|
||||
MRF(resblock_kernel_sizes, out, resblock_dilation_sizes)
|
||||
]
|
||||
hu = out
|
||||
base_channels = out
|
||||
self.generator = nn.Sequential(*generator)
|
||||
|
||||
self.output = nn.Sequential(
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.ReflectionPad1d(3),
|
||||
nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)),
|
||||
nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)),
|
||||
nn.Tanh()
|
||||
|
||||
)
|
||||
|
|
|
@ -6,44 +6,31 @@ class PeriodDiscriminator(nn.Module):
|
|||
|
||||
def __init__(self, period):
|
||||
super(PeriodDiscriminator, self).__init__()
|
||||
|
||||
layer = []
|
||||
self.period = period
|
||||
self.discriminator = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(1, 64, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(512, 1024, kernel_size=(5, 1))),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
),
|
||||
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))),
|
||||
])
|
||||
|
||||
inp = 1
|
||||
for l in range(4):
|
||||
out = int(2 ** (5 + l + 1))
|
||||
layer += [
|
||||
nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
|
||||
nn.LeakyReLU(0.2)
|
||||
]
|
||||
inp = out
|
||||
self.layer = nn.Sequential(*layer)
|
||||
self.output = nn.Sequential(
|
||||
nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.shape[0]
|
||||
pad = self.period - (x.shape[-1] % self.period)
|
||||
x = F.pad(x, (0, pad), "reflect")
|
||||
x = F.pad(x, (0, pad))
|
||||
y = x.view(batch_size, -1, self.period).contiguous()
|
||||
y = y.unsqueeze(1)
|
||||
features = list()
|
||||
for module in self.discriminator:
|
||||
y = module(y)
|
||||
features.append(y)
|
||||
return features[-1], features[:-1]
|
||||
out1 = self.layer(y)
|
||||
return self.output(out1)
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(nn.Module):
|
||||
|
@ -64,26 +51,25 @@ class MultiPeriodDiscriminator(nn.Module):
|
|||
PeriodDiscriminator(periods[1]),
|
||||
PeriodDiscriminator(periods[2]),
|
||||
PeriodDiscriminator(periods[3]),
|
||||
PeriodDiscriminator(periods[4]),
|
||||
PeriodDiscriminator(periods[4])
|
||||
])
|
||||
|
||||
self.msd = MelganMultiscaleDiscriminator(
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
num_scales=3,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=16,
|
||||
max_channels=1024,
|
||||
downsample_factors=(4, 4, 4),
|
||||
pooling_kernel_size=4,
|
||||
pooling_stride=2,
|
||||
pooling_padding=1
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
num_scales=num_scales,
|
||||
kernel_sizes=kernel_sizes,
|
||||
base_channels=base_channels,
|
||||
max_channels=max_channels,
|
||||
downsample_factors=downsample_factors,
|
||||
pooling_kernel_size=pooling_kernel_size,
|
||||
pooling_stride=pooling_stride,
|
||||
pooling_padding=pooling_padding
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
scores, feats = self.msd(x)
|
||||
for key, disc in enumerate(self.discriminators):
|
||||
score, feat = disc(x)
|
||||
score = disc(x)
|
||||
scores.append(score)
|
||||
feats.append(feat)
|
||||
return scores, feats
|
||||
|
|
|
@ -91,6 +91,14 @@ def setup_generator(c):
|
|||
num_res_blocks=c.wavernn_model_params['num_res_blocks'],
|
||||
hop_length=c.audio["hop_length"],
|
||||
sample_rate=c.audio["sample_rate"],)
|
||||
elif c.generator_model.lower() in 'hifigan_generator':
|
||||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
out_channels=1,
|
||||
base_channels=c.generator_model_params['upsample_initial_channel'],
|
||||
upsample_kernel=c.generator_model_params['upsample_kernel_sizes'],
|
||||
resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'],
|
||||
resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes'])
|
||||
elif c.generator_model.lower() in 'melgan_generator':
|
||||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
|
@ -162,6 +170,16 @@ def setup_discriminator(c):
|
|||
MyModel = importlib.import_module('TTS.vocoder.models.' +
|
||||
c.discriminator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
|
||||
if c.discriminator_model in 'multi_period_discriminator':
|
||||
model = MyModel(
|
||||
periods=c.discriminator_model_params['peroids'],
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=(5, 3),
|
||||
base_channels=c.discriminator_model_params['base_channels'],
|
||||
max_channels=c.discriminator_model_params['max_channels'],
|
||||
downsample_factors=c.
|
||||
discriminator_model_params['downsample_factors'])
|
||||
if c.discriminator_model in 'random_window_discriminator':
|
||||
model = MyModel(
|
||||
cond_channels=c.audio['num_mels'],
|
||||
|
|
Loading…
Reference in New Issue