mirror of https://github.com/coqui-ai/TTS.git
Rename LR scheduler
parent c98631fe36
commit bb2a88a984
train.py               | 4 ++--
utils/generic_utils.py | 4 ++--
train.py

@@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (
     remove_experiment_folder, create_experiment_folder, save_checkpoint,
     save_best_model, load_config, lr_decay, count_parameters, check_update,
-    get_commit_hash, sequence_mask, AnnealLR)
+    get_commit_hash, sequence_mask, NoamLR)
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
@@ -444,7 +444,7 @@ def main(args):
         criterion_st.cuda()
 
     if c.lr_decay:
-        scheduler = AnnealLR(
+        scheduler = NoamLR(
             optimizer,
             warmup_steps=c.warmup_steps,
             last_epoch=args.restore_step - 1)
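Context not shown in this diff: PyTorch schedulers reuse last_epoch as their step counter, so passing last_epoch=args.restore_step - 1 makes a restored run resume the schedule at restore_step, and the scheduler must then be stepped once per training iteration rather than once per epoch. A minimal usage sketch (the model, loop, and warmup_steps value are illustrative assumptions, not the repo's training loop):

import torch
from utils.generic_utils import NoamLR  # the class renamed in this commit

# Stand-in model/optimizer; only the scheduler calls mirror the hunk above.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = NoamLR(optimizer, warmup_steps=4000)  # 4000 is an example value

for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per iteration: last_epoch doubles as the global step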
utils/generic_utils.py
@@ -148,10 +148,10 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
-class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
+class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
-        super(AnnealLR, self).__init__(optimizer, last_epoch)
+        super(NoamLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
         step = max(self.last_epoch, 1)
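The hunk above is truncated inside get_lr. For reference, here is a self-contained sketch of a Noam-style scheduler as described in "Attention Is All You Need" (linear warmup followed by inverse-square-root decay); the get_lr body is an assumption about how the cut-off method continues, not a verbatim copy of the repo's code:

import torch

class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup for `warmup_steps` steps, then 1/sqrt(step) decay.
    Each param group's LR peaks at its base LR when step == warmup_steps."""

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super(NoamLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        # min() selects the warmup ramp (step * warmup^-1.5) before
        # warmup_steps and the decay branch (step^-0.5) after it.
        return [
            base_lr * self.warmup_steps**0.5
            * min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]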