Rename LR scheduler

This commit is contained in:
Eren Golge 2018-11-26 14:09:42 +01:00
parent c98631fe36
commit bb2a88a984
2 changed files with 4 additions and 4 deletions

View File

@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import (
remove_experiment_folder, create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay, count_parameters, check_update,
get_commit_hash, sequence_mask, AnnealLR)
get_commit_hash, sequence_mask, NoamLR)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
@ -444,7 +444,7 @@ def main(args):
criterion_st.cuda()
if c.lr_decay:
scheduler = AnnealLR(
scheduler = NoamLR(
optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)

View File

@ -148,10 +148,10 @@ def lr_decay(init_lr, global_step, warmup_steps):
return lr
class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps)
super(AnnealLR, self).__init__(optimizer, last_epoch)
super(NoamLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
step = max(self.last_epoch, 1)