mirror of https://github.com/coqui-ai/TTS.git

Commit 72e1357c80 (parent fd18e1cf34): Change config to json

train.py: 53 lines changed
train.py
@@ -1,9 +1,11 @@
 import os
 import sys
 import time
+import shutil
 import torch
 import signal
 import argparse
+import importlib
 import numpy as np
 
 import torch.nn as nn
@@ -11,37 +13,43 @@ from torch import optim
 from torch.autograd import Variable
 from torch.utils.data import DataLoader
 
-import train_config as c
 from utils.generic_utils import (Progbar, remove_experiment_folder,
-                                 create_experiment_folder, save_checkpoint)
+                                 create_experiment_folder, save_checkpoint,
+                                 load_config)
 from utils.model import get_param_size
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
 
 use_cuda = torch.cuda.is_available()
 
-_ = os.path.dirname(os.path.realpath(__file__))
-OUT_PATH = os.path.join(_, c.output_path)
-OUT_PATH = create_experiment_folder(OUT_PATH)
-
-def signal_handler(signal, frame):
-    print(" !! Pressed Ctrl+C !!")
-    remove_experiment_folder(OUT_PATH)
-    sys.exit(0)
-
 
 def main(args):
 
+    # setup output paths and read configs
+    c = load_config(args.config_path)
+    _ = os.path.dirname(os.path.realpath(__file__))
+    OUT_PATH = os.path.join(_, c.output_path)
+    OUT_PATH = create_experiment_folder(OUT_PATH)
+    CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
+    shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
+
+    # Ctrl+C handler to remove empty experiment folder
+    def signal_handler(signal, frame):
+        print(" !! Pressed Ctrl+C !!")
+        remove_experiment_folder(OUT_PATH)
+        sys.exit(0)
+    signal.signal(signal.SIGINT, signal_handler)
+
     dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
                               os.path.join(c.data_path, 'wavs'),
-                              c.dec_out_per_step
+                              c.r
                               )
 
     model = Tacotron(c.embedding_size,
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
-                     c.dec_out_per_step)
+                     c.r)
     if use_cuda:
         model = nn.DataParallel(model.cuda())
 
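The new setup block gives every run a timestamped experiment folder holding a copy of the config it was trained with, plus a checkpoints subdirectory. A sketch of the resulting layout (the folder name format comes from create_experiment_folder in utils/generic_utils.py below; the checkpoint file name follows the 'checkpoint_%d.pth.tar' pattern used for restoring; the exact date is just an example):

# ./result/                            <- c.output_path
#     January-21-2018_04:30PM/         <- OUT_PATH, named by create_experiment_folder
#         config.json                  <- copy of the file passed via --config_path
#         checkpoints/                 <- CHECKPOINT_PATH
#             checkpoint_128.pth.tar   <- assumed, from the 'checkpoint_%d.pth.tar'
#                                         pattern used when restoring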
@@ -49,7 +57,7 @@ def main(args):
 
     try:
         checkpoint = torch.load(os.path.join(
-            c.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step))
+            CHECKPOINT_PATH, 'checkpoint_%d.pth.tar' % args.restore_step))
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
         print("\n > Model restored from step %d\n" % args.restore_step)
@@ -59,8 +67,8 @@ def main(args):
 
     model = model.train()
 
-    if not os.path.exists(c.checkpoint_path):
-        os.mkdir(c.checkpoint_path)
+    if not os.path.exists(CHECKPOINT_PATH):
+        os.mkdir(CHECKPOINT_PATH)
 
     if use_cuda:
         criterion = nn.L1Loss().cuda()
@@ -71,10 +79,10 @@ def main(args):
 
     for epoch in range(c.epochs):
 
-        dataloader = DataLoader(dataset, batch_size=args.batch_size,
+        dataloader = DataLoader(dataset, batch_size=c.batch_size,
                                 shuffle=True, collate_fn=dataset.collate_fn,
                                 drop_last=True, num_workers=32)
-        progbar = Progbar(len(dataset) / args.batch_size)
+        progbar = Progbar(len(dataset) / c.batch_size)
 
         for i, data in enumerate(dataloader):
             text_input = data[0]
@@ -87,7 +95,7 @@ def main(args):
 
             try:
                 mel_input = np.concatenate((np.zeros(
-                    [args.batch_size, 1, c.num_mels], dtype=np.float32),
+                    [c.batch_size, 1, c.num_mels], dtype=np.float32),
                     mel_input[:, 1:, :]), axis=1)
             except:
                 raise TypeError("not same dimension")
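As a standalone sketch of what the mel_input manipulation above does (the shapes here are made up for illustration, not taken from the commit):

import numpy as np

batch_size, num_mels = 2, 80
mel_input = np.random.rand(batch_size, 4, num_mels).astype(np.float32)

# Prepend one all-zero frame and keep frames 1..T-1: the shape is unchanged,
# but the first frame is now all zeros (a zero <GO>-style frame for the
# decoder) while the remaining frames stay in place.
mel_input = np.concatenate((np.zeros(
    [batch_size, 1, num_mels], dtype=np.float32),
    mel_input[:, 1:, :]), axis=1)

assert mel_input.shape == (batch_size, 4, num_mels)
assert (mel_input[:, 0, :] == 0).all()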
@@ -175,12 +183,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--restore_step', type=int,
                         help='Global step to restore checkpoint', default=128)
-    parser.add_argument('--batch_size', type=int,
-                        help='Batch size', default=128)
-    parser.add_argument('--config', type=str,
+    parser.add_argument('--config_path', type=str,
                         help='path to config file for training',)
     args = parser.parse_args()
-
-    signal.signal(signal.SIGINT, signal_handler)
-
     main(args)
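With --batch_size and --config removed, the only command-line knobs left are --restore_step and --config_path; every other hyperparameter now lives in the JSON config. A plausible launch after this commit, assuming a config.json in the working directory (the file name is an assumption):

import subprocess

# Equivalent to running `python train.py --config_path config.json` in a shell.
subprocess.run(["python", "train.py", "--config_path", "config.json"],
               check=True)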
train_config.py (deleted)
@@ -1,33 +0,0 @@
-# Audio
-num_mels = 80
-num_freq = 1024
-sample_rate = 20000
-frame_length_ms = 50.
-frame_shift_ms = 12.5
-preemphasis = 0.97
-min_level_db = -100
-ref_level_db = 20
-hidden_size = 128
-embedding_size = 256
-
-# training
-epochs = 10000
-lr = 0.001
-decay_step = [500000, 1000000, 2000000]
-batch_size = 128
-max_iters = 200
-griffin_lim_iters = 60
-power = 1.5
-dec_out_per_step = 5
-#teacher_forcing_ratio = 1.0
-
-# outputing
-log_step = 100
-save_step = 2000
-
-# text processing
-cleaners = 'english_cleaners'
-
-# data settings
-data_path = '/data/shared/KeithIto/LJSpeech-1.0/'
-output_path = './result'

Binary file not shown.
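The deleted train_config.py implies the shape of the new config.json, which the diff itself does not show. A hedged reconstruction: every key and value below is copied from the deleted file, except dec_out_per_step, which is assumed to be renamed to r because train.py now reads c.r:

import json

# Assumed config.json contents, reconstructed from the deleted
# train_config.py; "r" replacing "dec_out_per_step" is an inference
# from the c.r references in train.py, not a confirmed key name.
config = {
    "num_mels": 80,
    "num_freq": 1024,
    "sample_rate": 20000,
    "frame_length_ms": 50.0,
    "frame_shift_ms": 12.5,
    "preemphasis": 0.97,
    "min_level_db": -100,
    "ref_level_db": 20,
    "hidden_size": 128,
    "embedding_size": 256,
    "epochs": 10000,
    "lr": 0.001,
    "decay_step": [500000, 1000000, 2000000],
    "batch_size": 128,
    "max_iters": 200,
    "griffin_lim_iters": 60,
    "power": 1.5,
    "r": 5,
    "log_step": 100,
    "save_step": 2000,
    "cleaners": "english_cleaners",
    "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
    "output_path": "./result",
}

with open("config.json", "w") as f:
    json.dump(config, f, indent=4)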
utils/generic_utils.py
@@ -4,9 +4,22 @@ import glob
 import time
 import shutil
 import datetime
+import json
 import numpy as np
 
 
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def load_config(config_path):
+    config = AttrDict()
+    config.update(json.load(open(config_path, "r")))
+    return config
+
+
 def create_experiment_folder(root_path):
     """ Create a folder with the current date and time """
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I:%M%p")
@@ -20,7 +33,7 @@ def remove_experiment_folder(experiment_path):
     """Check folder if there is a checkpoint, otherwise remove the folder"""
 
     checkpoint_files = glob.glob(experiment_path+"/*.pth.tar")
-    if len(checkpoint_files) == 0:
+    if len(checkpoint_files) < 2:
         shutil.rmtree(experiment_path)
         print(" ! Run is removed from {}".format(experiment_path))
     else:
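Two notes on the utils/generic_utils.py changes: load_config returns an AttrDict, a dict whose items are also attributes, which is what lets train.py keep writing c.batch_size after dropping the train_config module; and remove_experiment_folder now deletes a run folder unless it has accumulated at least two checkpoint files, where it previously kept any folder containing one. A minimal usage sketch for the new helpers, assuming the config.json reconstructed above exists on disk:

from utils.generic_utils import load_config

c = load_config("config.json")
print(c.num_mels)      # attribute access, as train.py uses it
print(c["num_mels"])   # still a plain dict underneath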