From 6edd8bc6dd02149211413a6cdef203be24717520 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 29 Mar 2019 17:01:08 +0100 Subject: [PATCH] add git branch and restore_path to copied config file for each run --- train.py | 11 +++++++---- utils/generic_utils.py | 21 ++++++++++++++++++--- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/train.py b/train.py index 83836bba..6a7f2ebd 100644 --- a/train.py +++ b/train.py @@ -19,11 +19,11 @@ from distribute import (DistributedSampler, apply_gradient_allreduce, from layers.losses import L1LossMasked, MSELossMasked from utils.audio import AudioProcessor from utils.generic_utils import (NoamLR, check_update, count_parameters, - create_experiment_folder, get_commit_hash, + create_experiment_folder, get_git_branch, load_config, lr_decay, remove_experiment_folder, save_best_model, save_checkpoint, sequence_mask, weight_decay, - set_init_dict) + set_init_dict, copy_config_file) from utils.logger import Logger from utils.synthesis import synthesis from utils.text.symbols import phonemes, symbols @@ -511,8 +511,11 @@ if __name__ == '__main__': if args.rank == 0: os.makedirs(AUDIO_PATH, exist_ok=True) - shutil.copyfile(args.config_path, os.path.join(OUT_PATH, - 'config.json')) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields) os.chmod(AUDIO_PATH, 0o775) os.chmod(OUT_PATH, 0o775) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 3afc15f3..1a791290 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -31,6 +31,12 @@ def load_config(config_path): return config +def get_git_branch(): + out = subprocess.check_output(["git", "branch"]).decode("utf8") + current = next(line for line in out.split("\n") if line.startswith("*")) + return current.replace("* ", "") + + def get_commit_hash(): """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" # try: @@ -71,10 +77,19 @@ def remove_experiment_folder(experiment_path): print(" ! Run is kept in {}".format(experiment_path)) -def copy_config_file(config_file, path): +def copy_config_file(config_file, out_path, new_fields): config_name = os.path.basename(config_file) - out_path = os.path.join(path, config_name) - shutil.copyfile(config_file, out_path) + config_lines = open(config_file, "r").readlines() + # add extra information fields + for key, value in new_fields.items(): + if type(value) == str: + new_line = '"{}":"{}",\n'.format(key, value) + else: + new_line = '"{}":{},\n'.format(key, value) + config_lines.insert(1, new_line) + config_out_file = open(out_path, "w") + config_out_file.writelines(config_lines) + config_out_file.close() def _trim_model_state_dict(state_dict):