From 105e0b4d62429f8618d8999631cb75724380b451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 9 Apr 2021 11:38:04 +0200 Subject: [PATCH] vocoder gan training fixes --- TTS/bin/train_vocoder_gan.py | 32 ++++++++++--------- .../models/melgan_multiscale_discriminator.py | 6 ++-- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 95cf612a..7681d660 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -472,14 +472,24 @@ def main(args): # pylint: disable=redefined-outer-name model_gen = setup_generator(c) model_disc = setup_discriminator(c) + # setup criterion + criterion_gen = GeneratorLoss(c) + criterion_disc = DiscriminatorLoss(c) + + if use_cuda: + model_gen.cuda() + criterion_gen.cuda() + model_disc.cuda() + criterion_disc.cuda() + # setup optimizers # TODO: allow loading custom optimizers optimizer_gen = None optimizer_disc = None optimizer_gen = getattr(torch.optim, c.optimizer) - optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params) + optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params) optimizer_disc = getattr(torch.optim, c.optimizer) - optimizer_disc = optimizer_disc(lr=c.lr_gen, **c.optimizer_params) + optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) # schedulers scheduler_gen = None @@ -493,10 +503,6 @@ def main(args): # pylint: disable=redefined-outer-name scheduler_disc = scheduler_disc( optimizer_disc, **c.lr_scheduler_disc_params) - # setup criterion - criterion_gen = GeneratorLoss(c) - criterion_disc = DiscriminatorLoss(c) - if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") checkpoint = torch.load(args.restore_path, map_location='cpu') @@ -533,11 +539,12 @@ def main(args): # pylint: disable=redefined-outer-name del model_dict # reset lr if not countinuining training. - for group in optimizer_gen.param_groups: - group['lr'] = c.lr_gen + if args.continue_path == '': + for group in optimizer_gen.param_groups: + group['lr'] = c.lr_gen - for group in optimizer_disc.param_groups: - group['lr'] = c.lr_disc + for group in optimizer_disc.param_groups: + group['lr'] = c.lr_disc print(f" > Model restored from step {checkpoint['step']:d}", flush=True) @@ -545,11 +552,6 @@ def main(args): # pylint: disable=redefined-outer-name else: args.restore_step = 0 - if use_cuda: - model_gen.cuda() - criterion_gen.cuda() - model_disc.cuda() - criterion_disc.cuda() # DISTRUBUTED if num_gpus > 1: diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py index 3ab6e13c..b01ab91f 100644 --- a/TTS/vocoder/models/melgan_multiscale_discriminator.py +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -15,8 +15,7 @@ class MelganMultiscaleDiscriminator(nn.Module): pooling_kernel_size=4, pooling_stride=2, pooling_padding=2, - groups_denominator=4, - max_groups=256): + groups_denominator=4): super(MelganMultiscaleDiscriminator, self).__init__() self.discriminators = nn.ModuleList([ @@ -26,8 +25,7 @@ class MelganMultiscaleDiscriminator(nn.Module): base_channels=base_channels, max_channels=max_channels, downsample_factors=downsample_factors, - groups_denominator=groups_denominator, - max_groups=max_groups) + groups_denominator=groups_denominator) for _ in range(num_scales) ])