add lr schedulers for generator and discriminator

erogol 2020-06-09 23:03:37 +02:00
parent 8552c1d991
commit f0144bfcba
1 changed file with 34 additions and 22 deletions


@@ -107,10 +107,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         global_step += 1
 
-        # get current learning rates
-        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
-        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
-
         ##############################
         # GENERATOR
         ##############################
@@ -166,9 +162,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                            c.gen_clip_grad)
         optimizer_G.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if scheduler_G is not None:
             scheduler_G.step()
 
         loss_dict = dict()
@@ -221,9 +215,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                            c.disc_clip_grad)
         optimizer_D.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if scheduler_D is not None:
             scheduler_D.step()
 
         for key, value in loss_D_dict.items():
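
Both schedulers are now stepped right after their optimizer whenever they exist, instead of only under the old noam_schedule flag. This matches the ordering PyTorch has expected since 1.1 (optimizer.step() before scheduler.step()). A minimal standalone sketch of the same pattern; the model, loss, and scheduler choice here are placeholders, not taken from this commit:

import torch

model = torch.nn.Linear(10, 10)                        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()             # placeholder loss
    loss.backward()
    optimizer.step()
    if scheduler is not None:                          # same guard as in the diff
        scheduler.step()                               # step after optimizer.step()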
@@ -232,6 +224,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         step_time = time.time() - start_time
         epoch_time += step_time
 
+        # get current learning rates
+        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
+        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
+
         # update avg stats
         update_train_values = dict()
         for key, value in loss_dict.items():
@@ -244,7 +240,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         if global_step % c.print_step == 0:
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       step_time, loader_time, current_lr_G,
-                                      loss_dict, keep_avg.avg_values)
+                                      current_lr_D, loss_dict,
+                                      keep_avg.avg_values)
 
         # plot step stats
         if global_step % 10 == 0:
@@ -262,8 +259,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                 # save model
                 save_checkpoint(model_G,
                                 optimizer_G,
+                                scheduler_G,
                                 model_D,
                                 optimizer_D,
+                                scheduler_D,
                                 global_step,
                                 epoch,
                                 OUT_PATH,
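
save_checkpoint's body is not part of this diff, but the call now passes both schedulers, and the restore code further down reads the keys 'scheduler' and 'scheduler_disc'. A hedged sketch of the checkpoint layout those two facts imply; the 'step' and 'epoch' keys and the helper body are assumptions, not code from this commit:

import torch

# Plausible shape of the saved dict; only the six state keys are
# confirmed by the restore code in this diff.
def save_checkpoint(model_G, optimizer_G, scheduler_G, model_D, optimizer_D,
                    scheduler_D, step, epoch, out_path):
    checkpoint = {
        'model': model_G.state_dict(),
        'optimizer': optimizer_G.state_dict(),
        'scheduler': scheduler_G.state_dict(),
        'model_disc': model_D.state_dict(),
        'optimizer_disc': optimizer_D.state_dict(),
        'scheduler_disc': scheduler_D.state_dict(),
        'step': step,    # assumption
        'epoch': epoch,  # assumption
    }
    torch.save(checkpoint, out_path)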
@@ -434,6 +433,7 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup audio processor
     ap = AudioProcessor(**c.audio)
+
     # DISTRUBUTED
     # if num_gpus > 1:
     #     init_distributed(args.rank, num_gpus, args.group_id,
@@ -449,6 +449,12 @@ def main(args): # pylint: disable=redefined-outer-name
                                lr=c.lr_disc,
                                weight_decay=0)
 
+    # schedulers
+    scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
+    scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
+    scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
+    scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
+
     # setup criterion
     criterion_gen = GeneratorLoss(c)
     criterion_disc = DiscriminatorLoss(c)
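
The scheduler classes are resolved by name from torch.optim.lr_scheduler via getattr and then instantiated with the keyword arguments from the config, so any stock PyTorch scheduler can be selected per network. A standalone sketch with hypothetical config values; the ExponentialLR name and its gamma are illustrative, the commit itself fixes no default:

import torch

# hypothetical config entries, mirroring c.lr_scheduler_gen and
# c.lr_scheduler_gen_params; any torch.optim.lr_scheduler class name works
lr_scheduler_gen = "ExponentialLR"
lr_scheduler_gen_params = {"gamma": 0.999, "last_epoch": -1}

optimizer_gen = torch.optim.Adam(
    torch.nn.Linear(10, 10).parameters(), lr=1e-4)     # placeholder optimizer

scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_gen)
scheduler_gen = scheduler_cls(optimizer_gen, **lr_scheduler_gen_params)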
@@ -456,12 +462,26 @@ def main(args): # pylint: disable=redefined-outer-name
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
+            print(" > Restoring Generator Model...")
             model_gen.load_state_dict(checkpoint['model'])
+            print(" > Restoring Generator Optimizer...")
             optimizer_gen.load_state_dict(checkpoint['optimizer'])
+            print(" > Restoring Discriminator Model...")
             model_disc.load_state_dict(checkpoint['model_disc'])
+            print(" > Restoring Discriminator Optimizer...")
             optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
+            if 'scheduler' in checkpoint:
+                print(" > Restoring Generator LR Scheduler...")
+                scheduler_gen.load_state_dict(checkpoint['scheduler'])
+                # NOTE: Not sure if necessary
+                scheduler_gen.optimizer = optimizer_gen
+            if 'scheduler_disc' in checkpoint:
+                print(" > Restoring Discriminator LR Scheduler...")
+                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
+                scheduler_disc.optimizer = optimizer_disc
         except RuntimeError:
-            print(" > Partial model initialization.")
+            # restore only matching layers.
+            print(" > Partial model initialization...")
             model_dict = model_gen.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model_gen.load_state_dict(model_dict)
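
On the scheduler restore: in stock PyTorch, a scheduler's state_dict() contains everything in its __dict__ except the optimizer reference, and load_state_dict() never touches scheduler.optimizer, so the reassignment flagged with "NOTE: Not sure if necessary" is a harmless safeguard rather than a requirement. A minimal round-trip sketch; all names here are placeholders:

import torch

model = torch.nn.Linear(10, 10)                        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

state = scheduler.state_dict()                         # what a checkpoint would store

# ... later, after rebuilding optimizer and scheduler from the config ...
scheduler.load_state_dict(state)                       # restores last_epoch and counters
scheduler.optimizer = optimizer                        # redundant safeguard, as above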
@@ -494,16 +514,6 @@ def main(args): # pylint: disable=redefined-outer-name
     # if num_gpus > 1:
     #     model = apply_gradient_allreduce(model)
 
-    if c.noam_schedule:
-        scheduler_gen = NoamLR(optimizer_gen,
-                               warmup_steps=c.warmup_steps_gen,
-                               last_epoch=args.restore_step - 1)
-        scheduler_disc = NoamLR(optimizer_disc,
-                                warmup_steps=c.warmup_steps_gen,
-                                last_epoch=args.restore_step - 1)
-    else:
-        scheduler_gen, scheduler_disc = None, None
-
     num_params = count_parameters(model_gen)
     print(" > Generator has {} parameters".format(num_params), flush=True)
     num_params = count_parameters(model_disc)
@@ -526,9 +536,11 @@ def main(args): # pylint: disable=redefined-outer-name
             best_loss = save_best_model(target_loss,
                                         best_loss,
                                         model_gen,
+                                        scheduler_gen,
                                         optimizer_gen,
                                         model_disc,
                                         optimizer_disc,
+                                        scheduler_disc,
                                         global_step,
                                         epoch,
                                         OUT_PATH,