mirror of https://github.com/coqui-ai/TTS.git
add git branch and restore_path to copied config file for each run
This commit is contained in:
parent
1ed4978e69
commit
6edd8bc6dd
11
train.py
11
train.py
|
@ -19,11 +19,11 @@ from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||||
from layers.losses import L1LossMasked, MSELossMasked
|
from layers.losses import L1LossMasked, MSELossMasked
|
||||||
from utils.audio import AudioProcessor
|
from utils.audio import AudioProcessor
|
||||||
from utils.generic_utils import (NoamLR, check_update, count_parameters,
|
from utils.generic_utils import (NoamLR, check_update, count_parameters,
|
||||||
create_experiment_folder, get_commit_hash,
|
create_experiment_folder, get_git_branch,
|
||||||
load_config, lr_decay,
|
load_config, lr_decay,
|
||||||
remove_experiment_folder, save_best_model,
|
remove_experiment_folder, save_best_model,
|
||||||
save_checkpoint, sequence_mask, weight_decay,
|
save_checkpoint, sequence_mask, weight_decay,
|
||||||
set_init_dict)
|
set_init_dict, copy_config_file)
|
||||||
from utils.logger import Logger
|
from utils.logger import Logger
|
||||||
from utils.synthesis import synthesis
|
from utils.synthesis import synthesis
|
||||||
from utils.text.symbols import phonemes, symbols
|
from utils.text.symbols import phonemes, symbols
|
||||||
|
@ -511,8 +511,11 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||||
shutil.copyfile(args.config_path, os.path.join(OUT_PATH,
|
new_fields = {}
|
||||||
'config.json'))
|
if args.restore_path:
|
||||||
|
new_fields["restore_path"] = args.restore_path
|
||||||
|
new_fields["github_branch"] = get_git_branch()
|
||||||
|
copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields)
|
||||||
os.chmod(AUDIO_PATH, 0o775)
|
os.chmod(AUDIO_PATH, 0o775)
|
||||||
os.chmod(OUT_PATH, 0o775)
|
os.chmod(OUT_PATH, 0o775)
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,12 @@ def load_config(config_path):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_branch():
|
||||||
|
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||||
|
current = next(line for line in out.split("\n") if line.startswith("*"))
|
||||||
|
return current.replace("* ", "")
|
||||||
|
|
||||||
|
|
||||||
def get_commit_hash():
|
def get_commit_hash():
|
||||||
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
"""https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
|
||||||
# try:
|
# try:
|
||||||
|
@ -71,10 +77,19 @@ def remove_experiment_folder(experiment_path):
|
||||||
print(" ! Run is kept in {}".format(experiment_path))
|
print(" ! Run is kept in {}".format(experiment_path))
|
||||||
|
|
||||||
|
|
||||||
def copy_config_file(config_file, path):
|
def copy_config_file(config_file, out_path, new_fields):
|
||||||
config_name = os.path.basename(config_file)
|
config_name = os.path.basename(config_file)
|
||||||
out_path = os.path.join(path, config_name)
|
config_lines = open(config_file, "r").readlines()
|
||||||
shutil.copyfile(config_file, out_path)
|
# add extra information fields
|
||||||
|
for key, value in new_fields.items():
|
||||||
|
if type(value) == str:
|
||||||
|
new_line = '"{}":"{}",\n'.format(key, value)
|
||||||
|
else:
|
||||||
|
new_line = '"{}":{},\n'.format(key, value)
|
||||||
|
config_lines.insert(1, new_line)
|
||||||
|
config_out_file = open(out_path, "w")
|
||||||
|
config_out_file.writelines(config_lines)
|
||||||
|
config_out_file.close()
|
||||||
|
|
||||||
|
|
||||||
def _trim_model_state_dict(state_dict):
|
def _trim_model_state_dict(state_dict):
|
||||||
|
|
Loading…
Reference in New Issue