* Format the model definition

* Update code and integrate training code
This commit is contained in:
rishikksh20 2021-02-24 08:01:51 +05:30 committed by Eren Gölge
parent 39b5845810
commit c20a6b1185
4 changed files with 57 additions and 52 deletions

View File

@ -83,7 +83,7 @@
"upsample_kernel_sizes": [16,16,4,4], "upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512, "upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11], "resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]] "resblock_dilation_sizes": [1,3,5]
}, },
// DATASET // DATASET

View File

@ -4,30 +4,31 @@ from TTS.vocoder.layers.hifigan import MRF
class Generator(nn.Module): class Generator(nn.Module):
def __init__(self, input_channel=80, hu=512, ku=[16, 16, 4, 4], kr=[3, 7, 11], Dr=[1, 3, 5]): def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
super(Generator, self).__init__() super(Generator, self).__init__()
self.input = nn.Sequential( self.input = nn.Sequential(
nn.ReflectionPad1d(3), nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(input_channel, hu, kernel_size=7)) nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
) )
generator = [] generator = []
for k in ku: for k in upsample_kernel:
inp = hu inp = base_channels
out = int(inp / 2) out = int(inp / 2)
generator += [ generator += [
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)), nn.utils.weight_norm(nn.ConvTranspose1d(inp, out, k, k//2)),
MRF(kr, out, Dr) MRF(resblock_kernel_sizes, out, resblock_dilation_sizes)
] ]
hu = out base_channels = out
self.generator = nn.Sequential(*generator) self.generator = nn.Sequential(*generator)
self.output = nn.Sequential( self.output = nn.Sequential(
nn.LeakyReLU(0.2), nn.LeakyReLU(0.2),
nn.ReflectionPad1d(3), nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(hu, 1, kernel_size=7, stride=1)), nn.utils.weight_norm(nn.Conv1d(base_channels, out_channels, kernel_size=7, stride=1)),
nn.Tanh() nn.Tanh()
) )

View File

@ -6,44 +6,31 @@ class PeriodDiscriminator(nn.Module):
def __init__(self, period): def __init__(self, period):
super(PeriodDiscriminator, self).__init__() super(PeriodDiscriminator, self).__init__()
layer = []
self.period = period self.period = period
self.discriminator = nn.ModuleList([ inp = 1
nn.Sequential( for l in range(4):
nn.utils.weight_norm(nn.Conv2d(1, 64, kernel_size=(5, 1), stride=(3, 1))), out = int(2 ** (5 + l + 1))
nn.LeakyReLU(0.2, inplace=True), layer += [
), nn.utils.weight_norm(nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))),
nn.Sequential( nn.LeakyReLU(0.2)
nn.utils.weight_norm(nn.Conv2d(64, 128, kernel_size=(5, 1), stride=(3, 1))), ]
nn.LeakyReLU(0.2, inplace=True), inp = out
), self.layer = nn.Sequential(*layer)
nn.Sequential( self.output = nn.Sequential(
nn.utils.weight_norm(nn.Conv2d(128, 256, kernel_size=(5, 1), stride=(3, 1))), nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))),
nn.LeakyReLU(0.2, inplace=True), nn.LeakyReLU(0.2),
), nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))
nn.Sequential( )
nn.utils.weight_norm(nn.Conv2d(256, 512, kernel_size=(5, 1), stride=(3, 1))),
nn.LeakyReLU(0.2, inplace=True),
),
nn.Sequential(
nn.utils.weight_norm(nn.Conv2d(512, 1024, kernel_size=(5, 1))),
nn.LeakyReLU(0.2, inplace=True),
),
nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1))),
])
def forward(self, x): def forward(self, x):
batch_size = x.shape[0] batch_size = x.shape[0]
pad = self.period - (x.shape[-1] % self.period) pad = self.period - (x.shape[-1] % self.period)
x = F.pad(x, (0, pad), "reflect") x = F.pad(x, (0, pad))
y = x.view(batch_size, -1, self.period).contiguous() y = x.view(batch_size, -1, self.period).contiguous()
y = y.unsqueeze(1) y = y.unsqueeze(1)
features = list() out1 = self.layer(y)
for module in self.discriminator: return self.output(out1)
y = module(y)
features.append(y)
return features[-1], features[:-1]
class MultiPeriodDiscriminator(nn.Module): class MultiPeriodDiscriminator(nn.Module):
@ -64,26 +51,25 @@ class MultiPeriodDiscriminator(nn.Module):
PeriodDiscriminator(periods[1]), PeriodDiscriminator(periods[1]),
PeriodDiscriminator(periods[2]), PeriodDiscriminator(periods[2]),
PeriodDiscriminator(periods[3]), PeriodDiscriminator(periods[3]),
PeriodDiscriminator(periods[4]), PeriodDiscriminator(periods[4])
]) ])
self.msd = MelganMultiscaleDiscriminator( self.msd = MelganMultiscaleDiscriminator(
in_channels=1, in_channels=in_channels,
out_channels=1, out_channels=out_channels,
num_scales=3, num_scales=num_scales,
kernel_sizes=(5, 3), kernel_sizes=kernel_sizes,
base_channels=16, base_channels=base_channels,
max_channels=1024, max_channels=max_channels,
downsample_factors=(4, 4, 4), downsample_factors=downsample_factors,
pooling_kernel_size=4, pooling_kernel_size=pooling_kernel_size,
pooling_stride=2, pooling_stride=pooling_stride,
pooling_padding=1 pooling_padding=pooling_padding
) )
def forward(self, x): def forward(self, x):
scores, feats = self.msd(x) scores, feats = self.msd(x)
for key, disc in enumerate(self.discriminators): for key, disc in enumerate(self.discriminators):
score, feat = disc(x) score = disc(x)
scores.append(score) scores.append(score)
feats.append(feat)
return scores, feats return scores, feats

View File

@ -91,6 +91,14 @@ def setup_generator(c):
num_res_blocks=c.wavernn_model_params['num_res_blocks'], num_res_blocks=c.wavernn_model_params['num_res_blocks'],
hop_length=c.audio["hop_length"], hop_length=c.audio["hop_length"],
sample_rate=c.audio["sample_rate"],) sample_rate=c.audio["sample_rate"],)
elif c.generator_model.lower() in 'hifigan_generator':
model = MyModel(
in_channels=c.audio['num_mels'],
out_channels=1,
base_channels=c.generator_model_params['upsample_initial_channel'],
upsample_kernel=c.generator_model_params['upsample_kernel_sizes'],
resblock_kernel_sizes=c.generator_model_params['resblock_kernel_sizes'],
resblock_dilation_sizes=c.generator_model_params['resblock_dilation_sizes'])
elif c.generator_model.lower() in 'melgan_generator': elif c.generator_model.lower() in 'melgan_generator':
model = MyModel( model = MyModel(
in_channels=c.audio['num_mels'], in_channels=c.audio['num_mels'],
@ -162,6 +170,16 @@ def setup_discriminator(c):
MyModel = importlib.import_module('TTS.vocoder.models.' + MyModel = importlib.import_module('TTS.vocoder.models.' +
c.discriminator_model.lower()) c.discriminator_model.lower())
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in 'multi_period_discriminator':
model = MyModel(
periods=c.discriminator_model_params['peroids'],
in_channels=1,
out_channels=1,
kernel_sizes=(5, 3),
base_channels=c.discriminator_model_params['base_channels'],
max_channels=c.discriminator_model_params['max_channels'],
downsample_factors=c.
discriminator_model_params['downsample_factors'])
if c.discriminator_model in 'random_window_discriminator': if c.discriminator_model in 'random_window_discriminator':
model = MyModel( model = MyModel(
cond_channels=c.audio['num_mels'], cond_channels=c.audio['num_mels'],