From 2daca15802513efce3c79a7ff02776508a13cc62 Mon Sep 17 00:00:00 2001 From: gerazov Date: Sat, 6 Feb 2021 22:25:56 +0100 Subject: [PATCH] restructured arg parsing and processing to utils --- TTS/utils/arguments.py | 207 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 TTS/utils/arguments.py diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py new file mode 100644 index 00000000..c3190e50 --- /dev/null +++ b/TTS/utils/arguments.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Argument parser for training scripts.""" + +import argparse +import re +import glob +import os + +from TTS.utils.generic_utils import ( + create_experiment_folder, get_git_branch) +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.io import copy_model_files, load_config +from TTS.utils.tensorboard_logger import TensorboardLogger + +from TTS.tts.utils.generic_utils import check_config_tts + + +def parse_arguments(argv): + """Parse command line arguments of training scripts. + + Parameters + ---------- + argv : list + This is a list of input arguments as given by sys.argv + + Returns + ------- + argparse.Namespace + Parsed arguments. + + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--continue_path", + type=str, + help=("Training output folder to continue training. Used to continue " + "a training. If it is used, "config_path" is ignored."), + default="", + required="--config_path" not in argv) + parser.add_argument( + "--restore_path", + type=str, + help="Model file to be restored. Use to finetune a model.", + default="") + parser.add_argument( + "--config_path", + type=str, + help="Path to config file for training.", + required="--continue_path" not in argv) + parser.add_argument( + "--debug", + type=bool, + default=False, + help="Do not verify commit integrity to run training.") + parser.add_argument( + "--rank", + type=int, + default=0, + help="DISTRIBUTED: process rank for distributed training.") + parser.add_argument( + "--group_id", + type=str, + default="", + help="DISTRIBUTED: process group id.") + + return parser.parse_args() + + +def get_last_checkpoint(path): + """Get latest checkpoint from a list of filenames. + + It is based on globbing for `*.pth.tar` and the RegEx + `checkpoint_([0-9]+)`. + + Parameters + ---------- + path : list + Path to files to be compared. + + Raises + ------ + ValueError + If no checkpoint files are found. + + Returns + ------- + last_checkpoint : str + Last checkpoint filename. + + """ + last_checkpoint_num = 0 + last_checkpoint = None + filenames = glob.glob( + os.path.join(path, "/*.pth.tar")) + for filename in filenames: + try: + checkpoint_num = int( + re.search(r"checkpoint_([0-9]+)", filename).groups()[0]) + if checkpoint_num > last_checkpoint_num: + last_checkpoint_num = checkpoint_num + last_checkpoint = filename + except AttributeError: # if there's no match in the filename + pass + if last_checkpoint is None: + raise ValueError(f"No checkpoints in {path}!") + else: + return last_checkpoint + + +def process_args(args, model_type): + """Process parsed comand line arguments. + + Parameters + ---------- + args : argparse.Namespace or dict like + Parsed input arguments. + model_type : str + Model type used to check config parameters and setup the TensorBoard + logger. One of: + - tacotron + - glow_tts + - speedy_speech + - gan + - wavegrad + - wavernn + + Raises + ------ + ValueError + If `model_type` is not one of implemented choices. + + Returns + ------- + c : TTS.utils.io.AttrDict + Config paramaters. + out_path : str + Path to save models and logging. + audio_path : str + Path to save generated test audios. + c_logger : TTS.utils.console_logger.ConsoleLogger + Class that does logging to the console. + tb_logger : TTS.utils.tensorboard.TensorboardLogger + Class that does the TensorBoard loggind. + + """ + if args.continue_path != "": + args.output_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + list_of_files = glob.glob( + os.path.join(args.continue_path, "*.pth.tar") + ) # * means all if need specific format then *.csv + args.restore_path = max(list_of_files, key=os.path.getctime) + # args.restore_path = get_last_checkpoint(args.continue_path) + print(f" > Training continues for {args.restore_path}") + + # setup output paths and read configs + c = load_config(args.config_path) + + if model_type in "tacotron glow_tts speedy_speech": + model_class = "TTS" + elif model_type in "gan wavegrad wavernn": + model_class = "VOCODER" + else: + raise ValueError("model type {model_type} not recognized!") + + if model_class == "TTS": + check_config_tts(c) + elif model_class == "VOCODER": + print("Vocoder config checker not implemented, " + "skipping ...") + else: + raise ValueError(f"model type {model_type} not recognized!") + + _ = os.path.dirname(os.path.realpath(__file__)) + + if model_type in "tacotron wavegrad wavernn" and c.mixed_precision: + print(" > Mixed precision mode is ON") + + out_path = args.continue_path + if args.continue_path == "": + out_path = create_experiment_folder(c.output_path, c.run_name, + args.debug) + + audio_path = os.path.join(out_path, "test_audios") + + c_logger = ConsoleLogger() + + if args.rank == 0: + os.makedirs(audio_path, exist_ok=True) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + copy_model_files(c, args.config_path, + out_path, new_fields) + os.chmod(audio_path, 0o775) + os.chmod(out_path, 0o775) + + log_path = out_path + + tb_logger = TensorboardLogger(log_path, model_name=model_class) + + # write model desc to tensorboard + tb_logger.tb_add_text("model-description", c["run_description"], 0) + + return c, out_path, audio_path, c_logger, tb_logger