diff --git a/config.json b/config.json index 89a30a6b..d94f8851 100644 --- a/config.json +++ b/config.json @@ -25,5 +25,6 @@ "text_cleaner": "english_cleaners", "data_path": "/data/shared/KeithIto/LJSpeech-1.0", - "output_path": "result" + "output_path": "result", + "log_dir": "/home/erogol/projects/TTS/logs/" } diff --git a/train.py b/train.py index 04122e48..d325a54f 100644 --- a/train.py +++ b/train.py @@ -13,6 +13,7 @@ import torch.nn as nn from torch import optim from torch.autograd import Variable from torch.utils.data import DataLoader +from tensorboardX import SummaryWriter from utils.generic_utils import (Progbar, remove_experiment_folder, create_experiment_folder, save_checkpoint, @@ -38,6 +39,10 @@ def main(args): tmp_path = os.path.join("/tmp/", file_name+'_tts') pickle.dump(c, open(tmp_path, "wb")) + # setup tensorboard + LOG_DIR = c.log_dir + tb = SummaryWriter(LOG_DIR) + # Ctrl+C handler to remove empty experiment folder def signal_handler(signal, frame): print(" !! Pressed Ctrl+C !!") @@ -78,7 +83,7 @@ def main(args): print("\n > Model restored from step %d\n" % args.restore_step) except: - print("\n > Starting a new training\n") + print("\n > Starting a new training") model = model.train() @@ -97,6 +102,7 @@ def main(args): dataloader = DataLoader(dataset, batch_size=c.batch_size, shuffle=True, collate_fn=dataset.collate_fn, drop_last=True, num_workers=32) + print("\n | > Epoch {}".format(epoch)) progbar = Progbar(len(dataset) / c.batch_size) for i, data in enumerate(dataloader): @@ -160,6 +166,10 @@ def main(args): ('linear_loss', linear_loss.data[0]), ('mel_loss', mel_loss.data[0])]) + tb.add_scalar('Train/TotalLoss', loss.data[0], current_step) + tb.add_scalar('Train/LinearLoss', linear_loss.data[0], current_step) + tb.add_scalar('Train/MelLoss', mel_loss.data[0], current_step) + if current_step % c.save_step == 0: checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) checkpoint_path = os.path.join(OUT_PATH, checkpoint_path) diff --git a/utils/.generic_utils.py.swp b/utils/.generic_utils.py.swp deleted file mode 100644 index 5b476405..00000000 Binary files a/utils/.generic_utils.py.swp and /dev/null differ