mirror of https://github.com/coqui-ai/TTS.git
add lr schedulers for generator and discriminator
This commit is contained in:
parent 8552c1d991
commit f0144bfcba
@@ -107,10 +107,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
        global_step += 1

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        ##############################
        # GENERATOR
        ##############################
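For context on the two learning-rate lines above: a PyTorch optimizer exposes its current (possibly scheduler-adjusted) learning rate through param_groups, so no extra bookkeeping is needed. A minimal sketch, using a toy Linear model and Adam as stand-ins for the vocoder models and optimizers used in this script:

import torch

# toy model/optimizer; the real script uses the GAN generator/discriminator pairs
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# same pattern as the training loop: param_groups is a list of dicts,
# and 'lr' reflects whatever value the scheduler last set
current_lr = optimizer.param_groups[0]['lr']
print(current_lr)  # 0.0001 until a scheduler changes it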
@@ -166,9 +162,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()

        # setup lr
        if c.noam_schedule:
        if scheduler_G is not None:
            scheduler_G.step()

        loss_dict = dict()
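The generator update follows the usual order: clip gradients, step the optimizer, then step the (optional) per-iteration scheduler. A minimal self-contained sketch of that ordering, with a toy module and an ExponentialLR standing in for whatever scheduler the config selects; max_norm plays the role of c.gen_clip_grad:

import torch

gen = torch.nn.Linear(8, 8)
optimizer_G = torch.optim.Adam(gen.parameters(), lr=1e-4)
scheduler_G = torch.optim.lr_scheduler.ExponentialLR(optimizer_G, gamma=0.999)

loss = gen(torch.randn(2, 8)).pow(2).mean()   # dummy generator loss
loss.backward()
torch.nn.utils.clip_grad_norm_(gen.parameters(), max_norm=5.0)
optimizer_G.step()
if scheduler_G is not None:                   # scheduler is optional
    scheduler_G.step()                        # decay applied once per training step
optimizer_G.zero_grad()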
@@ -221,9 +215,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
            torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                           c.disc_clip_grad)
        optimizer_D.step()

        # setup lr
        if c.noam_schedule:
        if scheduler_D is not None:
            scheduler_D.step()

        for key, value in loss_D_dict.items():
@@ -232,6 +224,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
@@ -244,7 +240,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      step_time, loader_time, current_lr_G,
                                      loss_dict, keep_avg.avg_values)
                                      current_lr_D, loss_dict,
                                      keep_avg.avg_values)

        # plot step stats
        if global_step % 10 == 0:
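Passing current_lr_D alongside current_lr_G makes both schedules visible in the console output. If you also want them on a dashboard, a hedged sketch with TensorBoard follows; this is not the repo's own logging helper, and the tags, log dir, and values are made up here:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/gan_lr_example")                # illustrative log dir
global_step, current_lr_G, current_lr_D = 100, 1e-4, 5e-5    # dummy values
writer.add_scalar("TrainIterStats/lr_gen", current_lr_G, global_step)
writer.add_scalar("TrainIterStats/lr_disc", current_lr_D, global_step)
writer.close()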
@@ -262,8 +259,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                # save model
                save_checkpoint(model_G,
                                optimizer_G,
                                scheduler_G,
                                model_D,
                                optimizer_D,
                                scheduler_D,
                                global_step,
                                epoch,
                                OUT_PATH,
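save_checkpoint now receives both schedulers so their state is persisted next to the models and optimizers. A hedged sketch of what that implies (the repo's save_checkpoint helper is not reproduced here); the 'model'/'optimizer'/'scheduler' keys mirror what the restore code further down reads, while the rest is illustrative:

import torch

model_gen = torch.nn.Linear(4, 4)
optimizer_gen = torch.optim.Adam(model_gen.parameters(), lr=1e-4)
scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(optimizer_gen, gamma=0.999)

state = {
    'model': model_gen.state_dict(),
    'optimizer': optimizer_gen.state_dict(),
    'scheduler': scheduler_gen.state_dict(),   # new: scheduler state travels with the checkpoint
    'step': 1000,                              # illustrative
    'epoch': 3,                                # illustrative
}
torch.save(state, 'checkpoint_example.pth')    # illustrative file name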
@@ -434,6 +433,7 @@ def main(args): # pylint: disable=redefined-outer-name

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    # if num_gpus > 1:
    #     init_distributed(args.rank, num_gpus, args.group_id,
@@ -449,6 +449,12 @@ def main(args): # pylint: disable=redefined-outer-name
                           lr=c.lr_disc,
                           weight_decay=0)

    # schedulers
    scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
    scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
    scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)
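The two getattr calls resolve scheduler classes by name from torch.optim.lr_scheduler, so the schedule is fully config-driven rather than hard-coded. A minimal sketch with an example config; the class name and params here (ExponentialLR, gamma) are illustrative, not defaults from the repo:

import torch

class ExampleConfig:                     # stand-in for the loaded JSON config `c`
    lr_scheduler_gen = "ExponentialLR"
    lr_scheduler_gen_params = {"gamma": 0.999, "last_epoch": -1}

c = ExampleConfig()
model_gen = torch.nn.Linear(4, 4)
optimizer_gen = torch.optim.Adam(model_gen.parameters(), lr=1e-4)

scheduler_cls = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
scheduler_gen = scheduler_cls(optimizer_gen, **c.lr_scheduler_gen_params)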
@@ -456,12 +462,26 @@ def main(args): # pylint: disable=redefined-outer-name
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint['model'])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint['model_disc'])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
            if 'scheduler' in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler_gen.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler_gen.optimizer = optimizer_gen
            if 'scheduler_disc' in checkpoint:
                print(" > Restoring Discriminator LR Scheduler...")
                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
                scheduler_disc.optimizer = optimizer_disc
        except RuntimeError:
            print(" > Partial model initialization.")
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)
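Restoring the scheduler state and then re-pointing scheduler.optimizer at the freshly built optimizer keeps a resumed run's schedule in sync with the checkpoint. A hedged, self-contained sketch of that path (checkpoint path and scheduler choice are illustrative):

import torch

model_gen = torch.nn.Linear(4, 4)
optimizer_gen = torch.optim.Adam(model_gen.parameters(), lr=1e-4)
scheduler_gen = torch.optim.lr_scheduler.ExponentialLR(optimizer_gen, gamma=0.999)

checkpoint = torch.load('checkpoint_example.pth', map_location='cpu')
if 'scheduler' in checkpoint:
    scheduler_gen.load_state_dict(checkpoint['scheduler'])
    scheduler_gen.optimizer = optimizer_gen   # reattach the live optimizer, as above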
@@ -494,16 +514,6 @@ def main(args): # pylint: disable=redefined-outer-name
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler_gen = NoamLR(optimizer_gen,
                               warmup_steps=c.warmup_steps_gen,
                               last_epoch=args.restore_step - 1)
        scheduler_disc = NoamLR(optimizer_disc,
                                warmup_steps=c.warmup_steps_gen,
                                last_epoch=args.restore_step - 1)
    else:
        scheduler_gen, scheduler_disc = None, None

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
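The hard-coded NoamLR/noam_schedule branch above is what this commit drops in favour of the config-driven schedulers. If a Noam-style warmup is still wanted, it can be expressed with a standard LambdaLR; the sketch below uses the usual Noam formula and is not code from the repo:

import torch

warmup_steps = 4000                      # illustrative warmup length
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def noam_lambda(step, warmup=warmup_steps):
    # classic Noam multiplier: linear warmup, then inverse-sqrt decay (peaks at 1.0)
    step = max(step, 1)
    return warmup ** 0.5 * min(step ** -0.5, step * warmup ** -1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda)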
@@ -526,9 +536,11 @@ def main(args): # pylint: disable=redefined-outer-name
            best_loss = save_best_model(target_loss,
                                        best_loss,
                                        model_gen,
                                        scheduler_gen,
                                        optimizer_gen,
                                        model_disc,
                                        optimizer_disc,
                                        scheduler_disc,
                                        global_step,
                                        epoch,
                                        OUT_PATH,