mirror of https://github.com/coqui-ai/TTS.git
Rename LR scheduler
This commit is contained in:
parent
c98631fe36
commit
bb2a88a984
4
train.py
4
train.py
|
@ -16,7 +16,7 @@ from tensorboardX import SummaryWriter
|
||||||
from utils.generic_utils import (
|
from utils.generic_utils import (
|
||||||
remove_experiment_folder, create_experiment_folder, save_checkpoint,
|
remove_experiment_folder, create_experiment_folder, save_checkpoint,
|
||||||
save_best_model, load_config, lr_decay, count_parameters, check_update,
|
save_best_model, load_config, lr_decay, count_parameters, check_update,
|
||||||
get_commit_hash, sequence_mask, AnnealLR)
|
get_commit_hash, sequence_mask, NoamLR)
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
@ -444,7 +444,7 @@ def main(args):
|
||||||
criterion_st.cuda()
|
criterion_st.cuda()
|
||||||
|
|
||||||
if c.lr_decay:
|
if c.lr_decay:
|
||||||
scheduler = AnnealLR(
|
scheduler = NoamLR(
|
||||||
optimizer,
|
optimizer,
|
||||||
warmup_steps=c.warmup_steps,
|
warmup_steps=c.warmup_steps,
|
||||||
last_epoch=args.restore_step - 1)
|
last_epoch=args.restore_step - 1)
|
||||||
|
|
|
@ -148,10 +148,10 @@ def lr_decay(init_lr, global_step, warmup_steps):
|
||||||
return lr
|
return lr
|
||||||
|
|
||||||
|
|
||||||
class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
|
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
|
||||||
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
|
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
|
||||||
self.warmup_steps = float(warmup_steps)
|
self.warmup_steps = float(warmup_steps)
|
||||||
super(AnnealLR, self).__init__(optimizer, last_epoch)
|
super(NoamLR, self).__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
step = max(self.last_epoch, 1)
|
step = max(self.last_epoch, 1)
|
||||||
|
|
Loading…
Reference in New Issue