add git branch and restore_path to copied config file for each run

This commit is contained in:
Eren Golge 2019-03-29 17:01:08 +01:00
parent 1ed4978e69
commit 6edd8bc6dd
2 changed files with 25 additions and 7 deletions

View File

@ -19,11 +19,11 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
from layers.losses import L1LossMasked, MSELossMasked
from utils.audio import AudioProcessor
from utils.generic_utils import (NoamLR, check_update, count_parameters,
create_experiment_folder, get_commit_hash,
create_experiment_folder, get_git_branch,
load_config, lr_decay,
remove_experiment_folder, save_best_model,
save_checkpoint, sequence_mask, weight_decay,
set_init_dict)
set_init_dict, copy_config_file)
from utils.logger import Logger
from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols
@ -511,8 +511,11 @@ if __name__ == '__main__':
if args.rank == 0:
os.makedirs(AUDIO_PATH, exist_ok=True)
shutil.copyfile(args.config_path, os.path.join(OUT_PATH,
'config.json'))
new_fields = {}
if args.restore_path:
new_fields["restore_path"] = args.restore_path
new_fields["github_branch"] = get_git_branch()
copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields)
os.chmod(AUDIO_PATH, 0o775)
os.chmod(OUT_PATH, 0o775)

View File

@ -31,6 +31,12 @@ def load_config(config_path):
return config
def get_git_branch():
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n") if line.startswith("*"))
return current.replace("* ", "")
def get_commit_hash():
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
# try:
@ -71,10 +77,19 @@ def remove_experiment_folder(experiment_path):
print(" ! Run is kept in {}".format(experiment_path))
def copy_config_file(config_file, path):
def copy_config_file(config_file, out_path, new_fields):
config_name = os.path.basename(config_file)
out_path = os.path.join(path, config_name)
shutil.copyfile(config_file, out_path)
config_lines = open(config_file, "r").readlines()
# add extra information fields
for key, value in new_fields.items():
if type(value) == str:
new_line = '"{}":"{}",\n'.format(key, value)
else:
new_line = '"{}":{},\n'.format(key, value)
config_lines.insert(1, new_line)
config_out_file = open(out_path, "w")
config_out_file.writelines(config_lines)
config_out_file.close()
def _trim_model_state_dict(state_dict):