mirror of https://github.com/coqui-ai/TTS.git
Make lr scheduler configurable
parent ec5de131fe
commit f95e8413ed

train.py: 10 changed lines

@@ -21,7 +21,7 @@ from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
 from layers.losses import L1LossMasked
 from utils.audio import AudioProcessor
-from torch.optim.lr_scheduler import StepLR
+
 
 torch.manual_seed(1)
 torch.set_num_threads(4)

@@ -337,6 +337,14 @@ def main(args):
     audio = importlib.import_module('utils.' + c.audio_processor)
     AudioProcessor = getattr(audio, 'AudioProcessor')
 
+    print(" > LR scheduler: {} ", c.lr_scheduler)
+    try:
+        scheduler = importlib.import_module('torch.optim.lr_scheduler')
+        scheduler = getattr(scheduler, c.lr_scheduler)
+    except:
+        scheduler = importlib.import_module('utils.generic_utils')
+        scheduler = getattr(scheduler, c.lr_scheduler)
+
     ap = AudioProcessor(
         sample_rate=c.sample_rate,
         num_mels=c.num_mels,

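The block added to main() resolves c.lr_scheduler by name: it first looks the class up in torch.optim.lr_scheduler (which is why the hard-coded StepLR import above could be dropped) and falls back to utils.generic_utils for custom schedulers. Two rough edges in the committed lines: the print call passes c.lr_scheduler as a second argument instead of formatting it into the string, and the bare except: also silences typos in the config value. Below is a minimal sketch of the same lookup with those issues tightened; get_scheduler_class is a hypothetical helper, not part of the commit.

import importlib

def get_scheduler_class(name):
    # Prefer PyTorch's built-in schedulers, then fall back to the
    # project's custom ones in utils.generic_utils; fail loudly on a
    # name that exists in neither module.
    for module_name in ('torch.optim.lr_scheduler', 'utils.generic_utils'):
        module = importlib.import_module(module_name)
        if hasattr(module, name):
            return getattr(module, name)
    raise ValueError("Unknown lr_scheduler: {}".format(name))

# Hypothetical usage, mirroring the diff:
# print(" > LR scheduler: {}".format(c.lr_scheduler))
# scheduler = get_scheduler_class(c.lr_scheduler)(optimizer)
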
@@ -142,6 +142,20 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
+class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(self, optimizer, warmup_steps=0.1):
+        self.warmup_steps = float(warmup_steps)
+        super(AnnealLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [
+            base_lr * self.warmup_steps**0.5 * torch.min([
+                self.last_epoch * self.warmup_steps**-1.5, self.last_epoch**
+                -0.5
+            ]) for base_lr in self.base_lrs
+        ]
+
+
 def mk_decay(init_mk, max_epoch, n_epoch):
     return init_mk * ((max_epoch - n_epoch) / max_epoch)

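The class added after lr_decay (presumably in utils/generic_utils.py, the fallback module the scheduler lookup above imports) is a Noam-style schedule: the learning rate rises during warmup and then decays as step**-0.5. As committed it cannot run: super().__init__ references an undefined last_epoch name, torch.min is applied to a plain Python list rather than a tensor, and the first call would compute 0**-0.5, which raises ZeroDivisionError. Below is a corrected sketch; the warmup_steps=4000 default is an assumption (the committed 0.1 looks unintended for a step count).

import torch


class AnnealLR(torch.optim.lr_scheduler._LRScheduler):
    # Noam-style schedule: roughly linear warmup for `warmup_steps`
    # optimizer steps, then inverse square-root decay.
    def __init__(self, optimizer, warmup_steps=4000, last_epoch=-1):
        # Assumed default of 4000 steps; the commit used 0.1.
        self.warmup_steps = float(warmup_steps)
        super(AnnealLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)  # avoid 0 ** -0.5 on the first call
        return [
            base_lr * self.warmup_steps**0.5 *
            min(step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]


# Hypothetical usage: call step() once per training step, not per epoch.
# scheduler = AnnealLR(optimizer, warmup_steps=4000)
# for _ in range(total_steps):
#     ...  # train one batch
#     scheduler.step()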