reformated docstrings in arguments.py

This commit is contained in:
gerazov 2021-02-12 11:36:01 +01:00
parent 702dff3edc
commit 0e78e31dbf
1 changed files with 32 additions and 53 deletions

View File

@ -19,16 +19,11 @@ from TTS.tts.utils.generic_utils import check_config_tts
def parse_arguments(argv):
"""Parse command line arguments of training scripts.
Parameters
----------
argv : list
This is a list of input arguments as given by sys.argv
Returns
-------
argparse.Namespace
Parsed arguments.
Args:
argv (list): This is a list of input arguments as given by sys.argv
Returns:
argparse.Namespace: Parsed arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
@ -46,7 +41,8 @@ def parse_arguments(argv):
parser.add_argument(
"--best_path",
type=str,
help="Best model file to be used for extracting best loss.",
help=("Best model file to be used for extracting best loss."
"If not specified, the latest best model in continue path is used"),
default="")
parser.add_argument(
"--config_path",
@ -78,21 +74,14 @@ def get_last_models(path):
It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`.
Parameters
----------
path : list
Path to files to be compared.
Args:
path (list): Path to files to be compared.
Raises
------
ValueError
If no checkpoint or best_model files are found.
Returns
-------
last_checkpoint : str
Last checkpoint filename.
Raises:
ValueError: If no checkpoint or best_model files are found.
Returns:
last_checkpoint (str): Last checkpoint filename.
"""
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
last_models = {}
@ -130,38 +119,28 @@ def get_last_models(path):
def process_args(args, model_type):
"""Process parsed comand line arguments.
Parameters
----------
args : argparse.Namespace or dict like
Parsed input arguments.
model_type : str
Model type used to check config parameters and setup the TensorBoard
logger. One of:
- tacotron
- glow_tts
- speedy_speech
- gan
- wavegrad
- wavernn
Args:
args (argparse.Namespace or dict like): Parsed input arguments.
model_type (str): Model type used to check config parameters and setup
the TensorBoard logger. One of:
- tacotron
- glow_tts
- speedy_speech
- gan
- wavegrad
- wavernn
Raises
------
ValueError
If `model_type` is not one of implemented choices.
Returns
-------
c : TTS.utils.io.AttrDict
Config paramaters.
out_path : str
Path to save models and logging.
audio_path : str
Path to save generated test audios.
c_logger : TTS.utils.console_logger.ConsoleLogger
Class that does logging to the console.
tb_logger : TTS.utils.tensorboard.TensorboardLogger
Class that does the TensorBoard loggind.
Raises:
ValueError: If `model_type` is not one of implemented choices.
Returns:
c (TTS.utils.io.AttrDict): Config paramaters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard loggind.
"""
if args.continue_path:
args.output_path = args.continue_path