reformated docstrings in arguments.py

This commit is contained in:
gerazov 2021-02-12 11:36:01 +01:00 committed by Eren Gölge
parent 62147994d4
commit 8cefa76bae
1 changed files with 19 additions and 29 deletions

View File

@ -18,16 +18,11 @@ from TTS.utils.tensorboard_logger import TensorboardLogger
def parse_arguments(argv): def parse_arguments(argv):
"""Parse command line arguments of training scripts. """Parse command line arguments of training scripts.
Parameters Args:
---------- argv (list): This is a list of input arguments as given by sys.argv
argv : list
This is a list of input arguments as given by sys.argv
Returns
-------
argparse.Namespace
Parsed arguments.
Returns:
argparse.Namespace: Parsed arguments.
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
@ -45,7 +40,8 @@ def parse_arguments(argv):
parser.add_argument( parser.add_argument(
"--best_path", "--best_path",
type=str, type=str,
help="Best model file to be used for extracting best loss.", help=("Best model file to be used for extracting best loss."
"If not specified, the latest best model in continue path is used"),
default="") default="")
parser.add_argument( parser.add_argument(
"--config_path", "--config_path",
@ -77,21 +73,14 @@ def get_last_models(path):
It is based on globbing for `*.pth.tar` and the RegEx It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`. `(checkpoint|best_model)_([0-9]+)`.
Parameters Args:
---------- path (list): Path to files to be compared.
path : list
Path to files to be compared.
Raises Raises:
------ ValueError: If no checkpoint or best_model files are found.
ValueError
If no checkpoint or best_model files are found.
Returns
-------
last_checkpoint : str
Last checkpoint filename.
Returns:
last_checkpoint (str): Last checkpoint filename.
""" """
file_names = glob.glob(os.path.join(path, "*.pth.tar")) file_names = glob.glob(os.path.join(path, "*.pth.tar"))
last_models = {} last_models = {}
@ -131,8 +120,8 @@ def process_args(args, model_type):
Args: Args:
args (argparse.Namespace or dict like): Parsed input arguments. args (argparse.Namespace or dict like): Parsed input arguments.
model_type (str): Model type used to check config parameters and setup the TensorBoard model_type (str): Model type used to check config parameters and setup
logger. One of: the TensorBoard logger. One of:
- tacotron - tacotron
- glow_tts - glow_tts
- speedy_speech - speedy_speech
@ -141,15 +130,16 @@ def process_args(args, model_type):
- wavernn - wavernn
Raises: Raises:
ValueError ValueError: If `model_type` is not one of implemented choices.
If `model_type` is not one of implemented choices.
Returns: Returns:
c (TTS.utils.io.AttrDict): Config paramaters. c (TTS.utils.io.AttrDict): Config paramaters.
out_path (str): Path to save models and logging. out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios. audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console. c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind. logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard loggind.
""" """
if args.continue_path: if args.continue_path:
args.output_path = args.continue_path args.output_path = args.continue_path