mirror of https://github.com/coqui-ai/TTS.git
Move things into main and set thread
This commit is contained in:
parent
bc284dc3d7
commit
766be91a09
44
train.py
44
train.py
|
@ -27,30 +27,11 @@ from datasets.LJSpeech import LJSpeechDataset
|
||||||
from models.tacotron import Tacotron
|
from models.tacotron import Tacotron
|
||||||
from layers.losses import L1LossMasked
|
from layers.losses import L1LossMasked
|
||||||
|
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
torch.set_num_threads(4)
|
||||||
use_cuda = torch.cuda.is_available()
|
use_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--restore_path', type=str,
|
|
||||||
help='Folder path to checkpoints', default=0)
|
|
||||||
parser.add_argument('--config_path', type=str,
|
|
||||||
help='path to config file for training',)
|
|
||||||
parser.add_argument('--debug', type=bool, default=False,
|
|
||||||
help='do not ask for git has before run.')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# setup output paths and read configs
|
|
||||||
c = load_config(args.config_path)
|
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
OUT_PATH = os.path.join(_, c.output_path)
|
|
||||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
|
|
||||||
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
|
||||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
|
||||||
|
|
||||||
# setup tensorboard
|
|
||||||
LOG_DIR = OUT_PATH
|
|
||||||
tb = SummaryWriter(LOG_DIR)
|
|
||||||
|
|
||||||
|
|
||||||
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
|
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
|
||||||
model = model.train()
|
model = model.train()
|
||||||
|
@ -440,6 +421,27 @@ def main(args):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--restore_path', type=str,
|
||||||
|
help='Folder path to checkpoints', default=0)
|
||||||
|
parser.add_argument('--config_path', type=str,
|
||||||
|
help='path to config file for training',)
|
||||||
|
parser.add_argument('--debug', type=bool, default=False,
|
||||||
|
help='do not ask for git has before run.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# setup output paths and read configs
|
||||||
|
c = load_config(args.config_path)
|
||||||
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
OUT_PATH = os.path.join(_, c.output_path)
|
||||||
|
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
|
||||||
|
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
|
||||||
|
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
|
||||||
|
|
||||||
|
# setup tensorboard
|
||||||
|
LOG_DIR = OUT_PATH
|
||||||
|
tb = SummaryWriter(LOG_DIR)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
main(args)
|
main(args)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
Loading…
Reference in New Issue