diff --git a/train.py b/train.py index 88886c3f..996e7b4e 100644 --- a/train.py +++ b/train.py @@ -8,6 +8,7 @@ import signal import argparse import importlib import pickle +import traceback import numpy as np import torch.nn as nn @@ -56,13 +57,6 @@ LOG_DIR = OUT_PATH tb = SummaryWriter(LOG_DIR) -def signal_handler(signal, frame): - """Ctrl+C handler to remove empty experiment folder""" - print(" !! Pressed Ctrl+C !!") - remove_experiment_folder(OUT_PATH) - sys.exit(1) - - def train(model, criterion, data_loader, optimizer, epoch): model = model.train() epoch_time = 0 @@ -369,5 +363,16 @@ def main(args): if __name__ == '__main__': - signal.signal(signal.SIGINT, signal_handler) - main(args) + # signal.signal(signal.SIGINT, signal_handler) + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) + except Exception: + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1)