diff --git a/train.py b/train.py
index f575bc48..6ca1f1fe 100644
--- a/train.py
+++ b/train.py
@@ -1,9 +1,11 @@
 import os
 import sys
 import time
+import shutil
 import torch
 import signal
 import argparse
+import importlib
 import numpy as np

 import torch.nn as nn
@@ -11,37 +13,43 @@ from torch import optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader

-import train_config as c
 from utils.generic_utils import (Progbar, remove_experiment_folder,
-                                 create_experiment_folder, save_checkpoint)
+                                 create_experiment_folder, save_checkpoint,
+                                 load_config)
 from utils.model import get_param_size
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron

 use_cuda = torch.cuda.is_available()

-_ = os.path.dirname(os.path.realpath(__file__))
-OUT_PATH = os.path.join(_, c.output_path)
-OUT_PATH = create_experiment_folder(OUT_PATH)
-
-def signal_handler(signal, frame):
-    print(" !! Pressed Ctrl+C !!")
-    remove_experiment_folder(OUT_PATH)
-    sys.exit(0)
-
 def main(args):
+    # setup output paths and read configs
+    c = load_config(args.config_path)
+    _ = os.path.dirname(os.path.realpath(__file__))
+    OUT_PATH = os.path.join(_, c.output_path)
+    OUT_PATH = create_experiment_folder(OUT_PATH)
+    CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
+    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
+
+    # Ctrl+C handler to remove empty experiment folder
+    def signal_handler(signal, frame):
+        print(" !! Pressed Ctrl+C !!")
+        remove_experiment_folder(OUT_PATH)
+        sys.exit(0)
+    signal.signal(signal.SIGINT, signal_handler)
+
     dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
                               os.path.join(c.data_path, 'wavs'),
-                              c.dec_out_per_step
+                              c.r
                               )

     model = Tacotron(c.embedding_size,
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
-                     c.dec_out_per_step)
+                     c.r)
     if use_cuda:
         model = nn.DataParallel(model.cuda())

@@ -49,7 +57,7 @@ def main(args):

     try:
         checkpoint = torch.load(os.path.join(
-            c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
+            CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         print("\n > Model restored from step %d\n" % args.restore_step)
@@ -59,8 +67,8 @@ def main(args):

     model = model.train()

-    if not os.path.exists(c.checkpoint_path):
-        os.mkdir(c.checkpoint_path)
+    if not os.path.exists(CHECKPOINT_PATH):
+        os.mkdir(CHECKPOINT_PATH)

     if use_cuda:
         criterion = nn.L1Loss().cuda()
@@ -71,10 +79,10 @@ def main(args):

     for epoch in range(c.epochs):

-        dataloader = DataLoader(dataset, batch_size=args.batch_size,
+        dataloader = DataLoader(dataset, batch_size=c.batch_size,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=32)
-        progbar = Progbar(len(dataset) / args.batch_size)
+        progbar = Progbar(len(dataset) / c.batch_size)

         for i, data in enumerate(dataloader):
             text_input = data[0]
@@ -87,7 +95,7 @@ def main(args):

             try:
                 mel_input = np.concatenate((np.zeros(
-                    [args.batch_size, 1, c.num_mels], dtype=np.float32),
+                    [c.batch_size, 1, c.num_mels], dtype=np.float32),
                     mel_input[:, 1:, :]), axis=1)
             except:
                 raise TypeError("not same dimension")
@@ -175,12 +183,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--restore_step', type=int,
                         help='Global step to restore checkpoint', default=128)
-    parser.add_argument('--batch_size', type=int,
-                        help='Batch size', default=128)
-    parser.add_argument('--config', type=str,
+    parser.add_argument('--config_path', type=str,
                         help='path to config file for training',)
     args = parser.parse_args()
-
-    signal.signal(signal.SIGINT, signal_handler)
-
     main(args)
diff --git a/train_config.py b/train_config.py
deleted file mode 100644
index 96a4a59d..00000000
--- a/train_config.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Audio
-num_mels = 80
-num_freq = 1024
-sample_rate = 20000
-frame_length_ms = 50.
-frame_shift_ms = 12.5
-preemphasis = 0.97
-min_level_db = -100
-ref_level_db = 20
-hidden_size = 128
-embedding_size = 256
-
-# training
-epochs = 10000
-lr = 0.001
-decay_step = [500000, 1000000, 2000000]
-batch_size = 128
-max_iters = 200
-griffin_lim_iters = 60
-power = 1.5
-dec_out_per_step = 5
-#teacher_forcing_ratio = 1.0
-
-# outputing
-log_step = 100
-save_step = 2000
-
-# text processing
-cleaners = 'english_cleaners'
-
-# data settings
-data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
-output_path = './result'
diff --git a/utils/.generic_utils.py.swp b/utils/.generic_utils.py.swp
index 9dbbe51f..4b46ae8a 100644
Binary files a/utils/.generic_utils.py.swp and b/utils/.generic_utils.py.swp differ
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 380a0e0c..e5fc0cb4 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -4,9 +4,22 @@ import glob
 import time
 import shutil
 import datetime
+import json

 import numpy as np

+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def load_config(config_path):
+    config = AttrDict()
+    config.update(json.load(open(config_path, "r")))
+    return config
+
+
 def create_experiment_folder(root_path):
     """ Create a folder with the current date and time """
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
@@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path):
     """Check folder if there is a checkpoint, otherwise remove the folder"""

     checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
-    if len(checkpoint_files) == 0:
+    if len(checkpoint_files) < 2:
        shutil.rmtree(experiment_path)
        print(" ! Run is removed from {}".format(experiment_path))
     else:
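
Note: the `config.json` that `train.py` now expects via `--config_path` is not part of this diff. Below is a minimal sketch of what it might contain, assuming its keys simply mirror the deleted `train_config.py` (with `dec_out_per_step` renamed to `r`, which is how `train.py` now reads it as `c.r`), together with a quick exercise of the new `load_config`/`AttrDict` path:

```python
# Hypothetical example config -- keys and values are copied from the
# deleted train_config.py; "r" is assumed to replace dec_out_per_step.
import json
from utils.generic_utils import load_config

example_config = {
    "num_mels": 80,
    "num_freq": 1024,
    "sample_rate": 20000,
    "frame_length_ms": 50.0,
    "frame_shift_ms": 12.5,
    "preemphasis": 0.97,
    "min_level_db": -100,
    "ref_level_db": 20,
    "hidden_size": 128,
    "embedding_size": 256,
    "epochs": 10000,
    "lr": 0.001,
    "decay_step": [500000, 1000000, 2000000],
    "batch_size": 128,
    "max_iters": 200,
    "griffin_lim_iters": 60,
    "power": 1.5,
    "r": 5,
    "log_step": 100,
    "save_step": 2000,
    "cleaners": "english_cleaners",
    "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
    "output_path": "./result"
}

# Write the config out, then read it back through the new loader.
with open("config.json", "w") as f:
    json.dump(example_config, f, indent=4)

c = load_config("config.json")
assert c.batch_size == 128  # AttrDict exposes JSON keys as attributes
assert c["r"] == c.r        # plain dict access still works
```

Because `AttrDict` sets `self.__dict__ = self`, every JSON key becomes both a dict entry and an attribute, so the rest of `train.py` keeps its `c.<name>` access pattern unchanged.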