add lr schedulers for generator and discriminator

erogol 2020-06-09 23:03:37 +02:00
parent 8552c1d991
commit f0144bfcba
1 changed file with 34 additions and 22 deletions


@@ -107,10 +107,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
global_step += 1
# get current learning rates
current_lr_G = list(optimizer_G.param_groups)[0]['lr']
current_lr_D = list(optimizer_D.param_groups)[0]['lr']
##############################
# GENERATOR
##############################
@@ -166,9 +162,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
c.gen_clip_grad)
optimizer_G.step()
# setup lr
if c.noam_schedule:
if scheduler_G is not None:
scheduler_G.step()
loss_dict = dict()
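
The guard replaces the old noam_schedule flag: any configured scheduler is stepped, and None disables scheduling. Below is a minimal, self-contained sketch of the ordering this hunk follows (the model, loss, and scheduler choice are illustrative assumptions, not from the repo); since PyTorch 1.1, scheduler.step() is expected after optimizer.step() within each iteration:

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss
        loss.backward()
        optimizer.step()
        if scheduler is not None:  # mirrors the guard in this hunk
            scheduler.step()
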
@@ -221,9 +215,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
c.disc_clip_grad)
optimizer_D.step()
# setup lr
if c.noam_schedule:
if scheduler_D is not None:
scheduler_D.step()
for key, value in loss_D_dict.items():
@@ -232,6 +224,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
step_time = time.time() - start_time
epoch_time += step_time
# get current learning rates
current_lr_G = list(optimizer_G.param_groups)[0]['lr']
current_lr_D = list(optimizer_D.param_groups)[0]['lr']
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
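
Fetching the learning rates was moved below the optimizer steps so the logged values reflect any scheduler update made in the same iteration. A small sketch of the lookup itself (the optimizer here is illustrative); note that param_groups is already a list, so the list(...) wrapper above is redundant but harmless:

    import torch

    optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-4)
    # param_groups is a list of dicts, one per parameter group; 'lr' holds
    # the group's current learning rate, updated in place by schedulers.
    current_lr = optimizer.param_groups[0]['lr']
    print(current_lr)  # 0.0001
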
@@ -244,7 +240,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
if global_step % c.print_step == 0:
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
step_time, loader_time, current_lr_G,
loss_dict, keep_avg.avg_values)
current_lr_D, loss_dict,
keep_avg.avg_values)
# plot step stats
if global_step % 10 == 0:
@@ -262,8 +259,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
# save model
save_checkpoint(model_G,
optimizer_G,
scheduler_G,
model_D,
optimizer_D,
scheduler_D,
global_step,
epoch,
OUT_PATH,
@@ -434,6 +433,7 @@ def main(args): # pylint: disable=redefined-outer-name
# setup audio processor
ap = AudioProcessor(**c.audio)
# DISTRIBUTED
# if num_gpus > 1:
# init_distributed(args.rank, num_gpus, args.group_id,
@@ -449,6 +449,12 @@ def main(args): # pylint: disable=redefined-outer-name
lr=c.lr_disc,
weight_decay=0)
# schedulers
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
# setup criterion
criterion_gen = GeneratorLoss(c)
criterion_disc = DiscriminatorLoss(c)
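
The scheduler classes are resolved by name from torch.optim.lr_scheduler, so the config strings must match PyTorch class names exactly and the *_params dicts must match that class's keyword arguments. A sketch of the same pattern with a stand-in config (the key names follow this commit; the values are illustrative assumptions, and a dict stands in for the attribute-style config object):

    import torch

    c = {'lr_scheduler_gen': 'ExponentialLR',
         'lr_scheduler_gen_params': {'gamma': 0.999, 'last_epoch': -1}}

    optimizer_gen = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-4)
    scheduler_cls = getattr(torch.optim.lr_scheduler, c['lr_scheduler_gen'])
    scheduler_gen = scheduler_cls(optimizer_gen, **c['lr_scheduler_gen_params'])
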
@@ -456,12 +462,26 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
print(" > Restoring Generator Model...")
model_gen.load_state_dict(checkpoint['model'])
print(" > Restoring Generator Optimizer...")
optimizer_gen.load_state_dict(checkpoint['optimizer'])
print(" > Restoring Discriminator Model...")
model_disc.load_state_dict(checkpoint['model_disc'])
print(" > Restoring Discriminator Optimizer...")
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
if 'scheduler' in checkpoint:
print(" > Restoring Generator LR Scheduler...")
scheduler_gen.load_state_dict(checkpoint['scheduler'])
# NOTE: Not sure if necessary
scheduler_gen.optimizer = optimizer_gen
if 'scheduler_disc' in checkpoint:
print(" > Restoring Discriminator LR Scheduler...")
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
scheduler_disc.optimizer = optimizer_disc
except RuntimeError:
print(" > Partial model initialization.")
# restore only matching layers.
print(" > Partial model initialization...")
model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_gen.load_state_dict(model_dict)
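
Schedulers round-trip through state_dict()/load_state_dict() just like optimizers, which is what both the save_checkpoint change above and this restore block rely on. A self-contained sketch of that round trip under the generator key name 'scheduler' (the discriminator's 'scheduler_disc' is analogous; model, paths, and hyperparameters are illustrative):

    import torch

    net = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(net.parameters(), lr=1e-4)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999)

    torch.save({'model': net.state_dict(),
                'optimizer': opt.state_dict(),
                'scheduler': sched.state_dict()}, '/tmp/ckpt_demo.pth')

    ckpt = torch.load('/tmp/ckpt_demo.pth', map_location='cpu')
    net.load_state_dict(ckpt['model'])
    opt.load_state_dict(ckpt['optimizer'])
    if 'scheduler' in ckpt:  # older checkpoints predate scheduler state
        sched.load_state_dict(ckpt['scheduler'])
        sched.optimizer = opt  # defensive reattach, as above; a scheduler's
                               # state_dict excludes the optimizer reference,
                               # so this is likely redundant but harmless
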
@@ -494,16 +514,6 @@ def main(args): # pylint: disable=redefined-outer-name
# if num_gpus > 1:
# model = apply_gradient_allreduce(model)
if c.noam_schedule:
scheduler_gen = NoamLR(optimizer_gen,
warmup_steps=c.warmup_steps_gen,
last_epoch=args.restore_step - 1)
scheduler_disc = NoamLR(optimizer_disc,
warmup_steps=c.warmup_steps_gen,
last_epoch=args.restore_step - 1)
else:
scheduler_gen, scheduler_disc = None, None
num_params = count_parameters(model_gen)
print(" > Generator has {} parameters".format(num_params), flush=True)
num_params = count_parameters(model_disc)
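
With the noam_schedule branch removed, warmup scheduling is no longer special-cased: the config block added in main() can pick any torch.optim.lr_scheduler class, and the is-not-None guards in train() tolerate a missing scheduler. If Noam-style warmup is still wanted, the generic LambdaLR can express it; this sketch assumes the common Noam shape (rate proportional to min(step^-0.5, step * warmup^-1.5)) rather than quoting the repo's NoamLR:

    import torch

    warmup_steps = 4000  # illustrative

    def noam_lambda(step):
        step = max(step, 1)  # avoid 0 ** -0.5 on the initial step
        return warmup_steps ** 0.5 * min(step ** -0.5,
                                         step * warmup_steps ** -1.5)

    optimizer = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=noam_lambda)
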
@@ -526,9 +536,11 @@ def main(args): # pylint: disable=redefined-outer-name
best_loss = save_best_model(target_loss,
best_loss,
model_gen,
scheduler_gen,
optimizer_gen,
model_disc,
optimizer_disc,
scheduler_disc,
global_step,
epoch,
OUT_PATH,