From 6818e1118513222b8b07d64eadf32ad5875d3b66 Mon Sep 17 00:00:00 2001
From: Eren
Date: Sun, 12 Aug 2018 15:02:06 +0200
Subject: [PATCH] Make lr scheduler configurable

---
 train.py               | 11 ++++++++++-
 utils/generic_utils.py | 16 ++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 8f908062..5ae4c03d 100644
--- a/train.py
+++ b/train.py
@@ -21,7 +21,7 @@ from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
-from torch.optim.lr_scheduler import StepLR
+
 
 torch.manual_seed(1)
 torch.set_num_threads(4)
@@ -337,6 +337,15 @@ def main(args):
     audio = importlib.import_module('utils.' + c.audio_processor)
     AudioProcessor = getattr(audio, 'AudioProcessor')
 
+    print(" > LR scheduler: {}".format(c.lr_scheduler))
+    # look up the scheduler in torch first, then in utils.generic_utils
+    try:
+        scheduler = importlib.import_module('torch.optim.lr_scheduler')
+        scheduler = getattr(scheduler, c.lr_scheduler)
+    except AttributeError:
+        scheduler = importlib.import_module('utils.generic_utils')
+        scheduler = getattr(scheduler, c.lr_scheduler)
+
     ap = AudioProcessor(
         sample_rate=c.sample_rate,
         num_mels=c.num_mels,
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 18968d9d..fea8c84c 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -142,6 +142,22 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
+class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
+    """Noam-style schedule: linear warmup, then inverse square-root decay."""
+
+    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
+        self.warmup_steps = float(warmup_steps)
+        super(AnnealLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        step = max(self.last_epoch, 1)  # floor avoids 0 ** -0.5 on first call
+        return [
+            base_lr * self.warmup_steps**0.5 *
+            min(step * self.warmup_steps**-1.5, step**-0.5)
+            for base_lr in self.base_lrs
+        ]
+
+
 def mk_decay(init_mk, max_epoch, n_epoch):
     return init_mk * ((max_epoch - n_epoch) / max_epoch)
 
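
Note (not part of the commit): a minimal standalone sketch of how the new
config-driven scheduler is meant to behave once wired up. The class body
mirrors the AnnealLR added above; the stand-in model, the Adam settings,
and warmup_steps=5 are illustrative assumptions, not values from this
repository.

    import torch
    from torch.optim.lr_scheduler import _LRScheduler

    class AnnealLR(_LRScheduler):
        """Noam-style schedule: linear warmup, then inverse sqrt decay."""

        def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
            self.warmup_steps = float(warmup_steps)
            super(AnnealLR, self).__init__(optimizer, last_epoch)

        def get_lr(self):
            step = max(self.last_epoch, 1)  # floor avoids 0 ** -0.5
            return [
                base_lr * self.warmup_steps**0.5 *
                min(step * self.warmup_steps**-1.5, step**-0.5)
                for base_lr in self.base_lrs
            ]

    model = torch.nn.Linear(4, 4)  # stand-in model (assumption)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
    scheduler = AnnealLR(optimizer, warmup_steps=5)

    for step in range(1, 11):
        optimizer.step()
        scheduler.step()
        print(step, optimizer.param_groups[0]['lr'])
    # The LR climbs linearly until step == warmup_steps (peaking at the
    # base lr), then decays proportionally to step ** -0.5.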