reformated docstrings in arguments.py

This commit is contained in:
gerazov 2021-02-12 11:36:01 +01:00
parent 702dff3edc
commit 0e78e31dbf
1 changed files with 32 additions and 53 deletions

View File

@ -19,16 +19,11 @@ from TTS.tts.utils.generic_utils import check_config_tts
def parse_arguments(argv): def parse_arguments(argv):
"""Parse command line arguments of training scripts. """Parse command line arguments of training scripts.
Parameters Args:
---------- argv (list): This is a list of input arguments as given by sys.argv
argv : list
This is a list of input arguments as given by sys.argv
Returns
-------
argparse.Namespace
Parsed arguments.
Returns:
argparse.Namespace: Parsed arguments.
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
@ -46,7 +41,8 @@ def parse_arguments(argv):
parser.add_argument( parser.add_argument(
"--best_path", "--best_path",
type=str, type=str,
help="Best model file to be used for extracting best loss.", help=("Best model file to be used for extracting best loss."
"If not specified, the latest best model in continue path is used"),
default="") default="")
parser.add_argument( parser.add_argument(
"--config_path", "--config_path",
@ -78,21 +74,14 @@ def get_last_models(path):
It is based on globbing for `*.pth.tar` and the RegEx It is based on globbing for `*.pth.tar` and the RegEx
`(checkpoint|best_model)_([0-9]+)`. `(checkpoint|best_model)_([0-9]+)`.
Parameters Args:
---------- path (list): Path to files to be compared.
path : list
Path to files to be compared.
Raises Raises:
------ ValueError: If no checkpoint or best_model files are found.
ValueError
If no checkpoint or best_model files are found.
Returns
-------
last_checkpoint : str
Last checkpoint filename.
Returns:
last_checkpoint (str): Last checkpoint filename.
""" """
file_names = glob.glob(os.path.join(path, "*.pth.tar")) file_names = glob.glob(os.path.join(path, "*.pth.tar"))
last_models = {} last_models = {}
@ -130,13 +119,10 @@ def get_last_models(path):
def process_args(args, model_type): def process_args(args, model_type):
"""Process parsed comand line arguments. """Process parsed comand line arguments.
Parameters Args:
---------- args (argparse.Namespace or dict like): Parsed input arguments.
args : argparse.Namespace or dict like model_type (str): Model type used to check config parameters and setup
Parsed input arguments. the TensorBoard logger. One of:
model_type : str
Model type used to check config parameters and setup the TensorBoard
logger. One of:
- tacotron - tacotron
- glow_tts - glow_tts
- speedy_speech - speedy_speech
@ -144,24 +130,17 @@ def process_args(args, model_type):
- wavegrad - wavegrad
- wavernn - wavernn
Raises Raises:
------ ValueError: If `model_type` is not one of implemented choices.
ValueError
If `model_type` is not one of implemented choices.
Returns
-------
c : TTS.utils.io.AttrDict
Config paramaters.
out_path : str
Path to save models and logging.
audio_path : str
Path to save generated test audios.
c_logger : TTS.utils.console_logger.ConsoleLogger
Class that does logging to the console.
tb_logger : TTS.utils.tensorboard.TensorboardLogger
Class that does the TensorBoard loggind.
Returns:
c (TTS.utils.io.AttrDict): Config paramaters.
out_path (str): Path to save models and logging.
audio_path (str): Path to save generated test audios.
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
logging to the console.
tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
the TensorBoard loggind.
""" """
if args.continue_path: if args.continue_path:
args.output_path = args.continue_path args.output_path = args.continue_path