From 72e1357c80220128cec7de9e31daf5353dc01ebb Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 22 Jan 2018 08:20:20 -0800 Subject: [PATCH] Change config to json --- train.py | 53 +++++++++++++++++++----------------- train_config.py | 33 ---------------------- utils/.generic_utils.py.swp | Bin 20480 -> 20480 bytes utils/generic_utils.py | 15 +++++++++- 4 files changed, 42 insertions(+), 59 deletions(-) delete mode 100644 train_config.py diff --git a/train.py b/train.py index f575bc48..6ca1f1fe 100644 --- a/train.py +++ b/train.py @@ -1,9 +1,11 @@ import os import sys import time +import shutil import torch import signal import argparse +import importlib import numpy as np import torch.nn as nn @@ -11,37 +13,43 @@ from torch import optim from torch.autograd import Variable from torch.utils.data import DataLoader -import train_config as c from utils.generic_utils import (Progbar, remove_experiment_folder, - create_experiment_folder, save_checkpoint) + create_experiment_folder, save_checkpoint, + load_config) from utils.model import get_param_size from datasets.LJSpeech import LJSpeechDataset from models.tacotron import Tacotron use_cuda = torch.cuda.is_available() -_ = os.path.dirname(os.path.realpath(__file__)) -OUT_PATH = os.path.join(_, c.output_path) -OUT_PATH = create_experiment_folder(OUT_PATH) - -def signal_handler(signal, frame): - print(" !! Pressed Ctrl+C !!") - remove_experiment_folder(OUT_PATH) - sys.exit(0) - def main(args): + # setup output paths and read configs + c = load_config(args.config_path) + _ = os.path.dirname(os.path.realpath(__file__)) + OUT_PATH = os.path.join(_, c.output_path) + OUT_PATH = create_experiment_folder(OUT_PATH) + CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints') + shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json')) + + # Ctrl+C handler to remove empty experiment folder + def signal_handler(signal, frame): + print(" !! Pressed Ctrl+C !!") + remove_experiment_folder(OUT_PATH) + sys.exit(0) + signal.signal(signal.SIGINT, signal_handler) + dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), os.path.join(c.data_path, 'wavs'), - c.dec_out_per_step + c.r ) model = Tacotron(c.embedding_size, c.hidden_size, c.num_mels, c.num_freq, - c.dec_out_per_step) + c.r) if use_cuda: model = nn.DataParallel(model.cuda()) @@ -49,7 +57,7 @@ def main(args): try: checkpoint = torch.load(os.path.join( - c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) + CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step)) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("\n > Model restored from step %d\n" % args.restore_step) @@ -59,8 +67,8 @@ def main(args): model = model.train() - if not os.path.exists(c.checkpoint_path): - os.mkdir(c.checkpoint_path) + if not os.path.exists(CHECKPOINT_PATH): + os.mkdir(CHECKPOINT_PATH) if use_cuda: criterion = nn.L1Loss().cuda() @@ -71,10 +79,10 @@ def main(args): for epoch in range(c.epochs): - dataloader = DataLoader(dataset, batch_size=args.batch_size, + dataloader = DataLoader(dataset, batch_size=c.batch_size, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=32) - progbar = Progbar(len(dataset) / args.batch_size) + progbar = Progbar(len(dataset) / c.batch_size) for i, data in enumerate(dataloader): text_input = data[0] @@ -87,7 +95,7 @@ def main(args): try: mel_input = np.concatenate((np.zeros( - [args.batch_size, 1, c.num_mels], dtype=np.float32), + [c.batch_size, 1, c.num_mels], dtype=np.float32), mel_input[:, 1:, :]), axis=1) except: raise TypeError("not same dimension") @@ -175,12 +183,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--restore_step', type=int, help='Global step to restore checkpoint', default=128) - parser.add_argument('--batch_size', type=int, - help='Batch size', default=128) - parser.add_argument('--config', type=str, + parser.add_argument('--config_path', type=str, help='path to config file for training',) args = parser.parse_args() - - signal.signal(signal.SIGINT, signal_handler) - main(args) diff --git a/train_config.py b/train_config.py deleted file mode 100644 index 96a4a59d..00000000 --- a/train_config.py +++ /dev/null @@ -1,33 +0,0 @@ -# Audio -num_mels = 80 -num_freq = 1024 -sample_rate = 20000 -frame_length_ms = 50. -frame_shift_ms = 12.5 -preemphasis = 0.97 -min_level_db = -100 -ref_level_db = 20 -hidden_size = 128 -embedding_size = 256 - -# training -epochs = 10000 -lr = 0.001 -decay_step = [500000, 1000000, 2000000] -batch_size = 128 -max_iters = 200 -griffin_lim_iters = 60 -power = 1.5 -dec_out_per_step = 5 -#teacher_forcing_ratio = 1.0 - -# outputing -log_step = 100 -save_step = 2000 - -# text processing -cleaners = 'english_cleaners' - -# data settings -data_path = '/data/shared/KeithIto/LJSpeech-1.0/' -output_path = './result' diff --git a/utils/.generic_utils.py.swp b/utils/.generic_utils.py.swp index 9dbbe51fbc482c58729a9f4c52deb32811f2f015..4b46ae8a4ffed90cb2142d9ba6881d2f6e011773 100644 GIT binary patch delta 2449 zcmeIz{f`qx9LMqL@ytB~5`qFt<)Cfa-g%l33Kh8nj|U-yCk0~4-gZm3E#1=Yffz#x z!DtjwSke0hy^D#Fy9gd7$3u*W8l!EUph1B8M%Qg_$VBm&1hka9EZV(s&qSaDJ!|r*Qy# zU|=#X3=!fGIo)roo%S>xge@5M;yE2w)>0TvbB=n@ZL3fL= z9#LdZNQ>%PJY^b5Yv7>Zw3~hYloj<`YI>SGz&6W=XQFMEk??C~YKwfTOEZ%ZV|_@r zy-L<>vqCA=in{Y-`_Gi1VrKmExkEGs?N(5%^$XUFug5p3`(E;lEN%M)m ziKJ+HmBNXh)XcDvTpwIwCi$g>i&`Rc%emPa(={s;HO=@f*BPDBv!O`BRIPt^gxG{z zuMS*k$s+TM9=OtZ18*pf%5OK5hGm(e>{EXc`)ecQjq;Px6A^zX6s97f5ap5QBHtrB z+fsVk)n74hzSGTUBWW<&D*bCYn*XrsO^|KIya)}t|3`9Oc{t2nKiQP{jZ5rl=&!*pcl_z8LDue zQg=a?ekE35Ii{eG@}9$sXu}ee!-01x^E)_+c0?;E`3xFWs6+)mqU1ez0|(KLM-YYF zk_(i+59hE4o3RE#Ou+Ay{{r4ZCw4%`WL#nky09N9EFv$(bQ|PyefuSQPpx9rC z_t4JcfQWGCl^`Kn;KBh|F&8Mr02EZ=5DvnI*#IGaq8p9a0~r>KxldoteuDiSsHnmw zq+%hS`3X^rY(L=?YuQA>7bCtLiaabsDEb-gB(|aq>tMp1UWnT$M-Dc^j9G6XdeMYa zaKMgL_@?7Rs6{@O>YPGIY@T`vaT!P8z-s95$y11z=s-J;p%f)ZhcBji>*uIPE|TDZ zo_V}1ije^mjQB`KFVLlBrA$VKs85m;y_6KQRqZrC@KTy%jq2{$YDp5u?^W+giwqV)P3y3Uc@5=^YNLi2K_{i7+}+hN@zivFimJTa02;AfdOB+=p-Ja z4Xx0Et$`M-Hb{bxDtT_!tK-&wozq^%t(a;`dnhLS*Y)2_^#Ogj2d%a;q7j96O!f%v zsKYK~LPj7am~I$%QHdn{V5%APpb0hpRI`YU2}4ZPfYaCsdMl#wmZ>hH0>xMWt*#wt z$64612Gc~npbfKc9aJ5BO?r21mC4|E!&lGt6rIl^!dNDivABk CSiSlH diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 380a0e0c..e5fc0cb4 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -4,9 +4,22 @@ import glob import time import shutil import datetime +import json import numpy as np +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def load_config(config_path): + config = AttrDict() + config.update(json.load(open(config_path, "r"))) + return config + + def create_experiment_folder(root_path): """ Create a folder with the current date and time """ date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p") @@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path): """Check folder if there is a checkpoint, otherwise remove the folder""" checkpoint_files = glob.glob(experiment_path+"/*.pth.tar") - if len(checkpoint_files) == 0: + if len(checkpoint_files) < 2: shutil.rmtree(experiment_path) print(" ! Run is removed from {}".format(experiment_path)) else: