From 49f61d0b9efaf270c569679b448027481490fbdb Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 26 Jan 2018 02:07:07 -0800
Subject: [PATCH] Checkpoint fix

---
 config.json            | 15 ++++++---------
 train.py               | 32 +++++++++-----------------------
 utils/generic_utils.py |  1 +
 3 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/config.json b/config.json
index d94f8851..a850c2eb 100644
--- a/config.json
+++ b/config.json
@@ -9,21 +9,18 @@
     "ref_level_db": 20,
     "hidden_size": 128,
     "embedding_size": 256,
+    "text_cleaner": "english_cleaners",
 
-    "epochs": 10000,
+    "epochs": 200,
     "lr": 0.01,
-    "decay_step": [500000, 1000000, 2000000],
-    "batch_size": 128,
-    "max_iters": 200,
+    "lr_patience": 2,
+    "lr_decay": 0.5,
+    "batch_size": 256,
     "griffinf_lim_iters": 60,
     "power": 1.5,
     "r": 5,
 
-    "log_step": 100,
-    "save_step": 2000,
-
-    "text_cleaner": "english_cleaners",
-
+    "save_step": 1,
     "data_path": "/data/shared/KeithIto/LJSpeech-1.0",
     "output_path": "result",
     "log_dir": "/home/erogol/projects/TTS/logs/"
diff --git a/train.py b/train.py
index d325a54f..737fe43b 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,7 @@
 import os
 import sys
 import time
+import datetime
 import shutil
 import torch
 import signal
@@ -13,6 +14,7 @@ import torch.nn as nn
 from torch import optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from tensorboardX import SummaryWriter
 
 from utils.generic_utils import (Progbar, remove_experiment_folder,
@@ -97,12 +99,15 @@ def main(args):
 
     n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
 
+    lr_scheduler = ReduceLROnPlateau(optimizer, factor=c.lr_decay,
+                                     patience=c.lr_patience, verbose=True)
+
     for epoch in range(c.epochs):
 
         dataloader = DataLoader(dataset, batch_size=c.batch_size,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=32)
-        print("\n | > Epoch {}".format(epoch))
+        print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
         for i, data in enumerate(dataloader):
@@ -162,7 +167,7 @@ def main(args):
             optimizer.step()
 
             time_per_step = time.time() - start_time
-            progbar.update(i, values=[('total_loss', loss.data[0]),
+            progbar.update(i+1, values=[('total_loss', loss.data[0]),
                                       ('linear_loss', linear_loss.data[0]),
                                       ('mel_loss', mel_loss.data[0])])
 
@@ -181,27 +186,8 @@ def main(args):
                             'mel_loss': mel_loss.data[0],
                             'date': datetime.date.today().strftime("%B %d, %Y")},
                            checkpoint_path)
-                print(" > Checkpoint is saved : {}".format(checkpoint_path))
-
-            if current_step in c.decay_step:
-                optimizer = adjust_learning_rate(optimizer, current_step)
-
-
-def adjust_learning_rate(optimizer, step):
-    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
-    if step == 500000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0005
-
-    elif step == 1000000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0003
-
-    elif step == 2000000:
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = 0.0001
-
-    return optimizer
+                print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
+        lr_scheduler.step(loss.data[0])
 
 
 if __name__ == '__main__':
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 4ca9e632..aa329fb7 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -5,6 +5,7 @@ import time
 import shutil
 import datetime
 import json
+import torch
 import numpy as np
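
Note on the scheduler change: the diff drops the hand-rolled adjust_learning_rate() step schedule (and the now-unused decay_step config key) in favor of torch.optim.lr_scheduler.ReduceLROnPlateau, configured by the new lr_decay and lr_patience keys and stepped on the latest training loss at the end of each epoch. The following is a minimal sketch of that pattern under assumed values; the tiny model, optimizer, and loss below are placeholders for illustration and are not code from train.py.

    # Sketch of the ReduceLROnPlateau wiring the patch adopts.
    # lr=0.01, factor=0.5, patience=2 mirror the "lr", "lr_decay", and
    # "lr_patience" entries added to config.json; everything else is a stand-in.
    import torch
    from torch import optim
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 1)                       # placeholder model
    optimizer = optim.Adam(model.parameters(), lr=0.01)  # "lr"
    lr_scheduler = ReduceLROnPlateau(optimizer, factor=0.5,    # "lr_decay"
                                     patience=2, verbose=True) # "lr_patience"

    for epoch in range(5):
        # Stand-in for one training epoch.
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()
        # Multiply the LR by `factor` once the monitored loss has not
        # improved for `patience` consecutive epochs.
        lr_scheduler.step(loss.item())

Because lr_scheduler.step(loss.data[0]) sits at the epoch-loop level in the patched train.py, the scheduler sees one loss value per epoch (the last batch's loss), which is what lr_patience counts against.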