styling formatting.py

This commit is contained in:
Eren Gölge 2021-05-25 14:41:13 +02:00
parent 120ea679f9
commit 534401377d
1 changed files with 26 additions and 13 deletions

View File

@ -29,16 +29,16 @@ def init_arguments(argv):
parser.add_argument( parser.add_argument(
"--continue_path", "--continue_path",
type=str, type=str,
help=( help=("Training output folder to continue training. Used to continue "
"Training output folder to continue training. Used to continue " "a training. If it is used, 'config_path' is ignored."),
"a training. If it is used, 'config_path' is ignored."
),
default="", default="",
required="--config_path" not in argv, required="--config_path" not in argv,
) )
parser.add_argument( parser.add_argument(
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" "--restore_path",
) type=str,
help="Model file to be restored. Use to finetune a model.",
default="")
parser.add_argument( parser.add_argument(
"--best_path", "--best_path",
type=str, type=str,
@ -48,12 +48,23 @@ def init_arguments(argv):
), ),
default="", default="",
) )
parser.add_argument("--config_path",
type=str,
help="Path to config file for training.",
required="--continue_path" not in argv)
parser.add_argument("--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.")
parser.add_argument( parser.add_argument(
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv "--rank",
) type=int,
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.") default=0,
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.") help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.") parser.add_argument("--group_id",
type=str,
default="",
help="DISTRIBUTED: process group id.")
return parser return parser
@ -148,7 +159,8 @@ def process_args(args):
print(" > Mixed precision mode is ON") print(" > Mixed precision mode is ON")
experiment_path = args.continue_path experiment_path = args.continue_path
if not experiment_path: if not experiment_path:
experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug) experiment_path = create_experiment_folder(config.output_path,
config.run_name, args.debug)
audio_path = os.path.join(experiment_path, "test_audios") audio_path = os.path.join(experiment_path, "test_audios")
# setup rank 0 process in distributed training # setup rank 0 process in distributed training
tb_logger = None tb_logger = None
@ -169,7 +181,8 @@ def process_args(args):
os.chmod(experiment_path, 0o775) os.chmod(experiment_path, 0o775)
tb_logger = TensorboardLogger(experiment_path, model_name=config.model) tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
# write model desc to tensorboard # write model desc to tensorboard
tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0) tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>",
0)
c_logger = ConsoleLogger() c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, tb_logger return config, experiment_path, audio_path, c_logger, tb_logger