Rename LR scheduler

This commit is contained in:
Eren Golge 2018-11-26 14:09:42 +01:00
parent c98631fe36
commit bb2a88a984
2 changed files with 4 additions and 4 deletions

View File

@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import ( from utils.generic_utils import (
remove_experiment_folder, create_experiment_folder, save_checkpoint, remove_experiment_folder, create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay, count_parameters, check_update, save_best_model, load_config, lr_decay, count_parameters, check_update,
get_commit_hash, sequence_mask, AnnealLR) get_commit_hash, sequence_mask, NoamLR)
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
@ -444,7 +444,7 @@ def main(args):
criterion_st.cuda() criterion_st.cuda()
if c.lr_decay: if c.lr_decay:
scheduler = AnnealLR( scheduler = NoamLR(
optimizer, optimizer,
warmup_steps=c.warmup_steps, warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1) last_epoch=args.restore_step - 1)

View File

@ -148,10 +148,10 @@ def lr_decay(init_lr, global_step, warmup_steps):
return lr return lr
class AnnealLR(torch.optim.lr_scheduler._LRScheduler): class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps) self.warmup_steps = float(warmup_steps)
super(AnnealLR, self).__init__(optimizer, last_epoch) super(NoamLR, self).__init__(optimizer, last_epoch)
def get_lr(self): def get_lr(self):
step = max(self.last_epoch, 1) step = max(self.last_epoch, 1)