diff --git a/vocoder/train.py b/vocoder/train.py
index 259c5eb7..c563dff0 100644
--- a/vocoder/train.py
+++ b/vocoder/train.py
@@ -107,10 +107,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         global_step += 1
 
-        # get current learning rates
-        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
-        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
-
         ##############################
         # GENERATOR
         ##############################
 
@@ -166,9 +162,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                            c.gen_clip_grad)
         optimizer_G.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if scheduler_G is not None:
             scheduler_G.step()
 
         loss_dict = dict()
@@ -221,9 +215,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                            c.disc_clip_grad)
         optimizer_D.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if scheduler_D is not None:
             scheduler_D.step()
 
         for key, value in loss_D_dict.items():
@@ -232,6 +224,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
 
         step_time = time.time() - start_time
         epoch_time += step_time
 
+        # get current learning rates
+        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
+        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
+
         # update avg stats
         update_train_values = dict()
@@ -244,7 +240,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         if global_step % c.print_step == 0:
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       step_time, loader_time, current_lr_G,
-                                      loss_dict, keep_avg.avg_values)
+                                      current_lr_D, loss_dict,
+                                      keep_avg.avg_values)
 
         # plot step stats
         if global_step % 10 == 0:
@@ -262,8 +259,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             # save model
             save_checkpoint(model_G,
                             optimizer_G,
+                            scheduler_G,
                             model_D,
                             optimizer_D,
+                            scheduler_D,
                             global_step,
                             epoch,
                             OUT_PATH,
@@ -434,6 +433,7 @@ def main(args):  # pylint: disable=redefined-outer-name
 
     # setup audio processor
     ap = AudioProcessor(**c.audio)
 
+    # DISTRIBUTED
     # if num_gpus > 1:
     #     init_distributed(args.rank, num_gpus, args.group_id,
@@ -449,6 +449,12 @@ def main(args):  # pylint: disable=redefined-outer-name
                            lr=c.lr_disc,
                            weight_decay=0)
 
+    # schedulers
+    scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
+    scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
+    scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
+    scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
+
     # setup criterion
     criterion_gen = GeneratorLoss(c)
     criterion_disc = DiscriminatorLoss(c)
@@ -456,12 +462,26 @@ def main(args):  # pylint: disable=redefined-outer-name
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
+            print(" > Restoring Generator Model...")
             model_gen.load_state_dict(checkpoint['model'])
+            print(" > Restoring Generator Optimizer...")
             optimizer_gen.load_state_dict(checkpoint['optimizer'])
+            print(" > Restoring Discriminator Model...")
             model_disc.load_state_dict(checkpoint['model_disc'])
+            print(" > Restoring Discriminator Optimizer...")
             optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
+            if 'scheduler' in checkpoint:
+                print(" > Restoring Generator LR Scheduler...")
+                scheduler_gen.load_state_dict(checkpoint['scheduler'])
+                # NOTE: Not sure if necessary
+                scheduler_gen.optimizer = optimizer_gen
+            if 'scheduler_disc' in checkpoint:
+                print(" > Restoring Discriminator LR Scheduler...")
+                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
+                scheduler_disc.optimizer = optimizer_disc
         except RuntimeError:
-            print(" > Partial model initialization.")
+            # restore only matching layers.
+            print(" > Partial model initialization...")
             model_dict = model_gen.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model_gen.load_state_dict(model_dict)
@@ -494,16 +514,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     # if num_gpus > 1:
     #     model = apply_gradient_allreduce(model)
 
-    if c.noam_schedule:
-        scheduler_gen = NoamLR(optimizer_gen,
-                               warmup_steps=c.warmup_steps_gen,
-                               last_epoch=args.restore_step - 1)
-        scheduler_disc = NoamLR(optimizer_disc,
-                                warmup_steps=c.warmup_steps_gen,
-                                last_epoch=args.restore_step - 1)
-    else:
-        scheduler_gen, scheduler_disc = None, None
-
     num_params = count_parameters(model_gen)
     print(" > Generator has {} parameters".format(num_params), flush=True)
     num_params = count_parameters(model_disc)
@@ -526,9 +536,11 @@ def main(args):  # pylint: disable=redefined-outer-name
             best_loss = save_best_model(target_loss,
                                         best_loss,
                                         model_gen,
+                                        scheduler_gen,
                                         optimizer_gen,
                                         model_disc,
                                         optimizer_disc,
+                                        scheduler_disc,
                                         global_step,
                                         epoch,
                                         OUT_PATH,
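
For context, here is a minimal sketch of the pattern this patch switches to: the scheduler class is resolved by name from `torch.optim.lr_scheduler`, instantiated with keyword params from the config, stepped only when one is configured, and its `state_dict()` travels with the checkpoint so the restore path can pick it up next to the optimizer. The scheduler name `ExponentialLR`, the `gamma` value, and the toy model below are illustrative assumptions, not values taken from this diff; the discriminator side follows the same pattern via `c.lr_scheduler_disc` / `c.lr_scheduler_disc_params`.

```python
import torch

# Hypothetical config values standing in for c.lr_scheduler_gen /
# c.lr_scheduler_gen_params; the concrete values here are made up.
lr_scheduler_gen = "ExponentialLR"
lr_scheduler_gen_params = {"gamma": 0.999, "last_epoch": -1}

model_gen = torch.nn.Linear(80, 1)
optimizer_gen = torch.optim.Adam(model_gen.parameters(), lr=1e-4)

# Resolve the scheduler class by name and instantiate it with its params,
# mirroring the getattr(torch.optim.lr_scheduler, ...) calls added in main().
scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_gen)
scheduler_gen = scheduler_cls(optimizer_gen, **lr_scheduler_gen_params)

# One toy optimization step, then step the scheduler only if one is set,
# as the train() hunks now do.
loss = model_gen(torch.randn(4, 80)).sum()
loss.backward()
optimizer_gen.step()
if scheduler_gen is not None:
    scheduler_gen.step()

# Carry the scheduler state in the checkpoint so a restore can resume the
# LR curve, matching the new 'scheduler' entry handled in the restore path.
checkpoint = {
    "model": model_gen.state_dict(),
    "optimizer": optimizer_gen.state_dict(),
    "scheduler": scheduler_gen.state_dict(),
}
scheduler_gen.load_state_dict(checkpoint["scheduler"])
```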