vocoder gan training fixes

This commit is contained in:
Eren Gölge 2021-04-09 11:38:04 +02:00
parent cd69da4868
commit 105e0b4d62
2 changed files with 19 additions and 19 deletions

View File

@ -472,14 +472,24 @@ def main(args): # pylint: disable=redefined-outer-name
model_gen = setup_generator(c)
model_disc = setup_discriminator(c)
# setup criterion
criterion_gen = GeneratorLoss(c)
criterion_disc = DiscriminatorLoss(c)
if use_cuda:
model_gen.cuda()
criterion_gen.cuda()
model_disc.cuda()
criterion_disc.cuda()
# setup optimizers
# TODO: allow loading custom optimizers
optimizer_gen = None
optimizer_disc = None
optimizer_gen = getattr(torch.optim, c.optimizer)
optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params)
optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
optimizer_disc = getattr(torch.optim, c.optimizer)
optimizer_disc = optimizer_disc(lr=c.lr_gen, **c.optimizer_params)
optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
# schedulers
scheduler_gen = None
@ -493,10 +503,6 @@ def main(args): # pylint: disable=redefined-outer-name
scheduler_disc = scheduler_disc(
optimizer_disc, **c.lr_scheduler_disc_params)
# setup criterion
criterion_gen = GeneratorLoss(c)
criterion_disc = DiscriminatorLoss(c)
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
@ -533,11 +539,12 @@ def main(args): # pylint: disable=redefined-outer-name
del model_dict
# reset lr if not countinuining training.
for group in optimizer_gen.param_groups:
group['lr'] = c.lr_gen
if args.continue_path == '':
for group in optimizer_gen.param_groups:
group['lr'] = c.lr_gen
for group in optimizer_disc.param_groups:
group['lr'] = c.lr_disc
for group in optimizer_disc.param_groups:
group['lr'] = c.lr_disc
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
@ -545,11 +552,6 @@ def main(args): # pylint: disable=redefined-outer-name
else:
args.restore_step = 0
if use_cuda:
model_gen.cuda()
criterion_gen.cuda()
model_disc.cuda()
criterion_disc.cuda()
# DISTRUBUTED
if num_gpus > 1:

View File

@ -15,8 +15,7 @@ class MelganMultiscaleDiscriminator(nn.Module):
pooling_kernel_size=4,
pooling_stride=2,
pooling_padding=2,
groups_denominator=4,
max_groups=256):
groups_denominator=4):
super(MelganMultiscaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
@ -26,8 +25,7 @@ class MelganMultiscaleDiscriminator(nn.Module):
base_channels=base_channels,
max_channels=max_channels,
downsample_factors=downsample_factors,
groups_denominator=groups_denominator,
max_groups=max_groups)
groups_denominator=groups_denominator)
for _ in range(num_scales)
])