mirror of https://github.com/coqui-ai/TTS.git

add lr schedulers for generator and discriminator

parent 8552c1d991
commit f0144bfcba
@@ -107,10 +107,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         global_step += 1

-        # get current learning rates
-        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
-        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
-
         ##############################
         # GENERATOR
         ##############################
@@ -166,9 +162,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                        c.gen_clip_grad)
         optimizer_G.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if scheduler_G is not None:
             scheduler_G.step()

         loss_dict = dict()
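Note: the generator update replaces the hard-coded `c.noam_schedule` switch with a plain `scheduler_G is not None` guard, so any scheduler object (or none at all) can be stepped once per iteration right after `optimizer_G.step()`. A minimal sketch of that pattern with stock PyTorch objects (the model, optimizer, and scheduler below are stand-ins, not the trainer's actual ones):

```python
import torch

# toy generator stand-in; the real model_G comes from the vocoder configs
model_G = torch.nn.Linear(10, 10)
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=1e-4)
# any torch.optim.lr_scheduler class works here, or None to disable scheduling
scheduler_G = torch.optim.lr_scheduler.ExponentialLR(optimizer_G, gamma=0.999)

for step in range(3):
    loss = model_G(torch.randn(4, 10)).pow(2).mean()
    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    # per-iteration scheduler update, guarded so "no scheduler" stays valid
    if scheduler_G is not None:
        scheduler_G.step()
```

Since the step happens inside the iteration loop (the same loop that increments `global_step`), per-step schedulers are the natural fit here rather than per-epoch ones.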
@@ -221,9 +215,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                        c.disc_clip_grad)
         optimizer_D.step()
-
-        # setup lr
-        if c.noam_schedule:
+        if c.scheduler_D is not None:
             scheduler_D.step()

         for key, value in loss_D_dict.items():
@@ -232,6 +224,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         step_time = time.time() - start_time
         epoch_time += step_time

+        # get current learning rates
+        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
+        current_lr_D = list(optimizer_D.param_groups)[0]['lr']
+
         # update avg stats
         update_train_values = dict()
         for key, value in loss_dict.items():
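The learning-rate reads removed at the top of the loop reappear here, after the optimizer and scheduler updates, so the values handed to the console logger below reflect the rate actually in effect. `param_groups` is already a list, so the `list(...)` wrapper in the diff is redundant but harmless. A small illustration:

```python
import torch

layer = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

layer(torch.randn(2, 4)).sum().backward()
optimizer.step()        # apply the update at the current rate (1e-4)
scheduler.step()        # decay the rate for the next iteration
current_lr = optimizer.param_groups[0]['lr']   # 5e-05 after one decay step
print(current_lr)
```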
@@ -244,7 +240,8 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         if global_step % c.print_step == 0:
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       step_time, loader_time, current_lr_G,
-                                      loss_dict, keep_avg.avg_values)
+                                      current_lr_D, loss_dict,
+                                      keep_avg.avg_values)

         # plot step stats
         if global_step % 10 == 0:
@@ -262,8 +259,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                 # save model
                 save_checkpoint(model_G,
                                 optimizer_G,
+                                scheduler_G,
                                 model_D,
                                 optimizer_D,
+                                scheduler_D,
                                 global_step,
                                 epoch,
                                 OUT_PATH,
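Both scheduler objects are now passed into `save_checkpoint`, so their state is stored alongside the model and optimizer state. The exact layout is decided by the helper in this repo, but judging from the keys the restore code below reads ('scheduler' and 'scheduler_disc'), the saved dict plausibly looks like this sketch (the helper name and the 'step'/'epoch' keys are assumptions for illustration only):

```python
import torch

def save_checkpoint_sketch(model_G, optimizer_G, scheduler_G,
                           model_D, optimizer_D, scheduler_D,
                           step, epoch, path):
    # illustrative layout only; the real save_checkpoint lives in the TTS utils
    state = {
        'model': model_G.state_dict(),
        'optimizer': optimizer_G.state_dict(),
        'scheduler': scheduler_G.state_dict() if scheduler_G is not None else None,
        'model_disc': model_D.state_dict(),
        'optimizer_disc': optimizer_D.state_dict(),
        'scheduler_disc': scheduler_D.state_dict() if scheduler_D is not None else None,
        'step': step,    # assumed key
        'epoch': epoch,  # assumed key
    }
    torch.save(state, path)
```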
@@ -434,6 +433,7 @@ def main(args): # pylint: disable=redefined-outer-name

     # setup audio processor
     ap = AudioProcessor(**c.audio)

     # DISTRUBUTED
     # if num_gpus > 1:
     #     init_distributed(args.rank, num_gpus, args.group_id,
@@ -449,6 +449,12 @@ def main(args): # pylint: disable=redefined-outer-name
                             lr=c.lr_disc,
                             weight_decay=0)

+    # schedulers
+    scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
+    scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
+    scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
+    scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
+
     # setup criterion
     criterion_gen = GeneratorLoss(c)
     criterion_disc = DiscriminatorLoss(c)
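This is the core of the commit: each scheduler class is looked up by name in `torch.optim.lr_scheduler` and instantiated with keyword arguments taken straight from the config, so any built-in scheduler can be selected from the config file without touching the training script. A minimal sketch with made-up config values (the field names mirror the diff, the concrete values are only examples, not project defaults):

```python
import torch

# stand-ins for the config fields read by the diff; values are illustrative
lr_scheduler_gen = "ExponentialLR"
lr_scheduler_gen_params = {"gamma": 0.999, "last_epoch": -1}

optimizer_gen = torch.optim.Adam(torch.nn.Linear(8, 8).parameters(), lr=1e-4)

# look the class up by name, then instantiate it with the config kwargs
scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler_gen)
scheduler_gen = scheduler_cls(optimizer_gen, **lr_scheduler_gen_params)
```

Any class name that exists in `torch.optim.lr_scheduler` (e.g. `StepLR`, `MultiStepLR`, `ExponentialLR`) works the same way, provided the params dict matches that class's constructor.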
@@ -456,12 +462,26 @@ def main(args): # pylint: disable=redefined-outer-name
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
         try:
+            print(" > Restoring Generator Model...")
             model_gen.load_state_dict(checkpoint['model'])
+            print(" > Restoring Generator Optimizer...")
             optimizer_gen.load_state_dict(checkpoint['optimizer'])
+            print(" > Restoring Discriminator Model...")
             model_disc.load_state_dict(checkpoint['model_disc'])
+            print(" > Restoring Discriminator Optimizer...")
             optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
+            if 'scheduler' in checkpoint:
+                print(" > Restoring Generator LR Scheduler...")
+                scheduler_gen.load_state_dict(checkpoint['scheduler'])
+                # NOTE: Not sure if necessary
+                scheduler_gen.optimizer = optimizer_gen
+            if 'scheduler_disc' in checkpoint:
+                print(" > Restoring Discriminator LR Scheduler...")
+                scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
+                scheduler_disc.optimizer = optimizer_disc
         except RuntimeError:
-            print(" > Partial model initialization.")
+            # retore only matching layers.
+            print(" > Partial model initialization...")
             model_dict = model_gen.state_dict()
             model_dict = set_init_dict(model_dict, checkpoint['model'], c)
             model_gen.load_state_dict(model_dict)
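On resume, scheduler state is restored only when the checkpoint actually contains it, which keeps checkpoints written before this commit loadable. Re-pointing `scheduler.optimizer` at the freshly restored optimizer is a belt-and-braces step, as the in-code NOTE admits. A standalone sketch of the same round trip:

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

# pretend this dict was produced by the training run being resumed
checkpoint = {'model': model.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict()}

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
    scheduler.load_state_dict(checkpoint['scheduler'])
    scheduler.optimizer = optimizer  # usually redundant; mirrors the diff's NOTE
```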
@@ -494,16 +514,6 @@ def main(args): # pylint: disable=redefined-outer-name
     # if num_gpus > 1:
     #     model = apply_gradient_allreduce(model)

-    if c.noam_schedule:
-        scheduler_gen = NoamLR(optimizer_gen,
-                               warmup_steps=c.warmup_steps_gen,
-                               last_epoch=args.restore_step - 1)
-        scheduler_disc = NoamLR(optimizer_disc,
-                                warmup_steps=c.warmup_steps_gen,
-                                last_epoch=args.restore_step - 1)
-    else:
-        scheduler_gen, scheduler_disc = None, None
-
     num_params = count_parameters(model_gen)
     print(" > Generator has {} parameters".format(num_params), flush=True)
     num_params = count_parameters(model_disc)
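The old hard-coded `NoamLR` warmup block is removed here because `NoamLR` lives in this repo's training utilities rather than in `torch.optim.lr_scheduler`, so the new `getattr`-based construction above cannot select it. If a Noam-style warmup is still wanted under the new scheme, one possible workaround (my suggestion, not something this commit provides) is `LambdaLR` with the Noam factor:

```python
import torch

warmup_steps = 4000  # illustrative value
optimizer = torch.optim.Adam(torch.nn.Linear(8, 8).parameters(), lr=1e-4)

def noam_factor(step):
    # classic Noam schedule: linear warmup, then inverse square-root decay
    step = max(step, 1)
    return warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_factor)
```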
@@ -526,9 +536,11 @@ def main(args): # pylint: disable=redefined-outer-name
         best_loss = save_best_model(target_loss,
                                     best_loss,
                                     model_gen,
+                                    scheduler_gen,
                                     optimizer_gen,
                                     model_disc,
                                     optimizer_disc,
+                                    scheduler_disc,
                                     global_step,
                                     epoch,
                                     OUT_PATH,