Move things into main and set thread

This commit is contained in:
Eren G 2018-07-17 15:59:31 +02:00
parent bc284dc3d7
commit 766be91a09
1 changed files with 23 additions and 21 deletions

View File

@ -27,30 +27,11 @@ from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron from models.tacotron import Tacotron
from layers.losses import L1LossMasked from layers.losses import L1LossMasked
torch.manual_seed(1) torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git has before run.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch): def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
model = model.train() model = model.train()
@ -440,6 +421,27 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git has before run.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
try: try:
main(args) main(args)
except KeyboardInterrupt: except KeyboardInterrupt: