mirror of https://github.com/coqui-ai/TTS.git
add git branch and restore_path to copied config file for each run
This commit is contained in:
parent
1ed4978e69
commit
6edd8bc6dd
11
train.py
11
train.py
|
@ -19,11 +19,11 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
|
|||
from layers.losses import L1LossMasked, MSELossMasked
|
||||
from utils.audio import AudioProcessor
|
||||
from utils.generic_utils import (NoamLR, check_update, count_parameters,
|
||||
create_experiment_folder, get_commit_hash,
|
||||
create_experiment_folder, get_git_branch,
|
||||
load_config, lr_decay,
|
||||
remove_experiment_folder, save_best_model,
|
||||
save_checkpoint, sequence_mask, weight_decay,
|
||||
set_init_dict)
|
||||
set_init_dict, copy_config_file)
|
||||
from utils.logger import Logger
|
||||
from utils.synthesis import synthesis
|
||||
from utils.text.symbols import phonemes, symbols
|
||||
|
@ -511,8 +511,11 @@ if __name__ == '__main__':
|
|||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH,
|
||||
'config.json'))
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
|
|
|
@ -31,6 +31,12 @@ def load_config(config_path):
|
|||
return config
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||
current = next(line for line in out.split("\n") if line.startswith("*"))
|
||||
return current.replace("* ", "")
|
||||
|
||||
|
||||
def get_commit_hash():
|
||||
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
||||
# try:
|
||||
|
@ -71,10 +77,19 @@ def remove_experiment_folder(experiment_path):
|
|||
print(" ! Run is kept in {}".format(experiment_path))
|
||||
|
||||
|
||||
def copy_config_file(config_file, path):
|
||||
def copy_config_file(config_file, out_path, new_fields):
|
||||
config_name = os.path.basename(config_file)
|
||||
out_path = os.path.join(path, config_name)
|
||||
shutil.copyfile(config_file, out_path)
|
||||
config_lines = open(config_file, "r").readlines()
|
||||
# add extra information fields
|
||||
for key, value in new_fields.items():
|
||||
if type(value) == str:
|
||||
new_line = '"{}":"{}",\n'.format(key, value)
|
||||
else:
|
||||
new_line = '"{}":{},\n'.format(key, value)
|
||||
config_lines.insert(1, new_line)
|
||||
config_out_file = open(out_path, "w")
|
||||
config_out_file.writelines(config_lines)
|
||||
config_out_file.close()
|
||||
|
||||
|
||||
def _trim_model_state_dict(state_dict):
|
||||
|
|
Loading…
Reference in New Issue