Move things into main and set thread

This commit is contained in:
Eren G 2018-07-17 15:59:31 +02:00
parent 90d7e885e7
commit d4f1ccd3ed
1 changed files with 23 additions and 21 deletions

View File

@ -27,30 +27,11 @@ from datasets.LJSpeech import LJSpeechDataset
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
torch.manual_seed(1)
torch.set_num_threads(4)
use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git has before run.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
model = model.train()
@ -440,6 +421,27 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--restore_path', type=str,
help='Folder path to checkpoints', default=0)
parser.add_argument('--config_path', type=str,
help='path to config file for training',)
parser.add_argument('--debug', type=bool, default=False,
help='do not ask for git has before run.')
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
OUT_PATH = os.path.join(_, c.output_path)
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
CHECKPOINT_PATH = os.path.join(OUT_PATH, 'checkpoints')
shutil.copyfile(args.config_path, os.path.join(OUT_PATH, 'config.json'))
# setup tensorboard
LOG_DIR = OUT_PATH
tb = SummaryWriter(LOG_DIR)
try:
main(args)
except KeyboardInterrupt: