Change config to json

This commit is contained in:
Eren Golge 2018-01-22 08:20:20 -08:00
parent fd18e1cf34
commit 72e1357c80
4 changed files with 42 additions and 59 deletions

View File

@ -1,9 +1,11 @@
import os import os
import sys import sys
import time import time
import shutil
import torch import torch
import signal import signal
import argparse import argparse
import importlib
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
@ -11,37 +13,43 @@ from torch import optim
from torch.autograd import Variable from torch.autograd import Variable
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import train_config as c
from utils.generic_utils import (Progbar, remove_experiment_folder, from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint) create_experiment_folder, save_checkpoint,
load_config)
from utils.model import get_param_size from utils.model import get_param_size
from datasets.LJSpeech import LJSpeechDataset from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron from models.tacotron import Tacotron
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(0)
def main(args): def main(args):
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# Ctrl+C handler to remove empty experiment folder
def signal_handler(signal, frame):
print(" !! Pressed Ctrl+C !!")
remove_experiment_folder(OUT_PATH)
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'), dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
os.path.join(c.data_path, 'wavs'), os.path.join(c.data_path, 'wavs'),
c.dec_out_per_step c.r
) )
model = Tacotron(c.embedding_size, model = Tacotron(c.embedding_size,
c.hidden_size, c.hidden_size,
c.num_mels, c.num_mels,
c.num_freq, c.num_freq,
c.dec_out_per_step) c.r)
if use_cuda: if use_cuda:
model = nn.DataParallel(model.cuda()) model = nn.DataParallel(model.cuda())
@ -49,7 +57,7 @@ def main(args):
try: try:
checkpoint = torch.load(os.path.join( checkpoint = torch.load(os.path.join(
c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)) CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % args.restore_step) print("\n > Model restored from step %d\n" % args.restore_step)
@ -59,8 +67,8 @@ def main(args):
model = model.train() model = model.train()
if not os.path.exists(c.checkpoint_path): if not os.path.exists(CHECKPOINT_PATH):
os.mkdir(c.checkpoint_path) os.mkdir(CHECKPOINT_PATH)
if use_cuda: if use_cuda:
criterion = nn.L1Loss().cuda() criterion = nn.L1Loss().cuda()
@ -71,10 +79,10 @@ def main(args):
for epoch in range(c.epochs): for epoch in range(c.epochs):
dataloader = DataLoader(dataset, batch_size=args.batch_size, dataloader = DataLoader(dataset, batch_size=c.batch_size,
shuffle=True, collate_fn=dataset.collate_fn, shuffle=True, collate_fn=dataset.collate_fn,
drop_last=True, num_workers=32) drop_last=True, num_workers=32)
progbar = Progbar(len(dataset) / args.batch_size) progbar = Progbar(len(dataset) / c.batch_size)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
text_input = data[0] text_input = data[0]
@ -87,7 +95,7 @@ def main(args):
try: try:
mel_input = np.concatenate((np.zeros( mel_input = np.concatenate((np.zeros(
[args.batch_size, 1, c.num_mels], dtype=np.float32), [c.batch_size, 1, c.num_mels], dtype=np.float32),
mel_input[:, 1:, :]), axis=1) mel_input[:, 1:, :]), axis=1)
except: except:
raise TypeError("not same dimension") raise TypeError("not same dimension")
@ -175,12 +183,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--restore_step', type=int, parser.add_argument('--restore_step', type=int,
help='Global step to restore checkpoint', default=128) help='Global step to restore checkpoint', default=128)
parser.add_argument('--batch_size', type=int, parser.add_argument('--config_path', type=str,
help='Batch size', default=128)
parser.add_argument('--config', type=str,
help='path to config file for training',) help='path to config file for training',)
args = parser.parse_args() args = parser.parse_args()
signal.signal(signal.SIGINT, signal_handler)
main(args) main(args)

View File

@ -1,33 +0,0 @@
# Audio
num_mels = 80
num_freq = 1024
sample_rate = 20000
frame_length_ms = 50.
frame_shift_ms = 12.5
preemphasis = 0.97
min_level_db = -100
ref_level_db = 20
hidden_size = 128
embedding_size = 256
# training
epochs = 10000
lr = 0.001
decay_step = [500000, 1000000, 2000000]
batch_size = 128
max_iters = 200
griffin_lim_iters = 60
power = 1.5
dec_out_per_step = 5
#teacher_forcing_ratio = 1.0
# outputing
log_step = 100
save_step = 2000
# text processing
cleaners = 'english_cleaners'
# data settings
data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
output_path = './result'

Binary file not shown.

View File

@ -4,9 +4,22 @@ import glob
import time import time
import shutil import shutil
import datetime import datetime
import json
import numpy as np import numpy as np
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def load_config(config_path):
config = AttrDict()
config.update(json.load(open(config_path, "r")))
return config
def create_experiment_folder(root_path): def create_experiment_folder(root_path):
""" Create a folder with the current date and time """ """ Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path):
"""Check folder if there is a checkpoint, otherwise remove the folder""" """Check folder if there is a checkpoint, otherwise remove the folder"""
checkpoint_files = glob.glob(experiment_path+"/*.pth.tar") checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
if len(checkpoint_files) == 0: if len(checkpoint_files) < 2:
shutil.rmtree(experiment_path) shutil.rmtree(experiment_path)
print(" ! Run is removed from {}".format(experiment_path)) print(" ! Run is removed from {}".format(experiment_path))
else: else: