mirror of https://github.com/coqui-ai/TTS.git
197 lines
6.8 KiB
Python
197 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""Argument parser for training scripts."""
|
|
|
|
import argparse
|
|
import glob
|
|
import os
|
|
import re
|
|
|
|
import torch
|
|
from TTS.tts.utils.text.symbols import parse_symbols
|
|
from TTS.utils.console_logger import ConsoleLogger
|
|
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
|
from TTS.utils.io import copy_model_files, load_config
|
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
|
|
|
|
|
def _str2bool(value):
    """Convert a command line token into a boolean.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (every non-empty string is truthy).

    Args:
        value (str or bool): Raw token from the command line.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.

    Returns:
        bool: Parsed boolean value.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    # the empty string stays falsy, exactly as it was with ``type=bool``
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"Boolean value expected, got {value!r}.")


def parse_arguments(argv):
    """Parse command line arguments of training scripts.

    Args:
        argv (list): This is a list of input arguments as given by sys.argv.
            It is only inspected to decide which one of ``--continue_path``
            and ``--config_path`` is mandatory; the parsing itself is done by
            ``parser.parse_args()`` without arguments, i.e. on ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help=("Training output folder to continue training. Used to continue "
              "a training. If it is used, 'config_path' is ignored."),
        default="",
        # either a run to continue or a config file must be given
        required="--config_path" not in argv)
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="")
    parser.add_argument(
        "--best_path",
        type=str,
        help=("Best model file to be used for extracting best loss."
              "If not specified, the latest best model in continue path is used"),
        default="")
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required="--continue_path" not in argv)
    parser.add_argument(
        "--debug",
        # explicit converter so that `--debug False` really disables the flag;
        # a bare `--debug` (no value) switches it on.
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Do not verify commit integrity to run training.")
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.")
    parser.add_argument(
        "--group_id",
        type=str,
        default="",
        help="DISTRIBUTED: process group id.")

    return parser.parse_args()
|
|
|
|
|
|
def get_last_checkpoint(path):
    """Get latest checkpoint or/and best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`, applied to the file names only so
    that directory names in `path` cannot influence the result.

    Args:
        path (str): Folder holding the model files to be compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        last_checkpoint (str): Most recent model file (checkpoint or best
            model, whichever has the larger step number).
        last_best_model (str or None): Most recent best model file, or None
            if the folder has no best model.
    """
    file_names = glob.glob(os.path.join(path, "*.pth.tar"))
    last_models = {}
    last_model_nums = {}
    for key in ("checkpoint", "best_model"):
        pattern = re.compile(f"{key}_([0-9]+)")
        last_model_num = None
        last_model = None
        key_file_names = []
        # pass all the model files and find the one with the largest
        # model number suffix.
        for file_name in file_names:
            base_name = os.path.basename(file_name)
            if key in base_name:
                key_file_names.append(file_name)
            match = pattern.search(base_name)
            if match is None:
                continue
            model_num = int(match.group(1))
            if last_model_num is None or model_num > last_model_num:
                last_model_num = model_num
                last_model = file_name

        # if no numbered file was found above, fall back to the file with
        # the latest ctime and read its step from the stored state.
        if last_model is None and key_file_names:
            last_model = max(key_file_names, key=os.path.getctime)
            # only the step counter is needed, so keep the tensors on CPU
            last_model_num = torch.load(
                last_model, map_location="cpu")['step']

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if 'checkpoint' not in last_models:  # no checkpoint just best model
        last_models['checkpoint'] = last_models['best_model']
    elif 'best_model' not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models['best_model'] = None
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums['best_model'] > last_model_nums['checkpoint']:
        last_models['checkpoint'] = last_models['best_model']

    return last_models['checkpoint'], last_models['best_model']
|
|
|
|
|
|
def process_args(args, model_class):
    """Process parsed command line arguments based on model class (tts or vocoder).

    When `args.continue_path` is set, `args.output_path`, `args.config_path`,
    `args.restore_path` and (if empty) `args.best_path` are overwritten in
    place from the content of that folder.

    Args:
        args (argparse.Namespace): Parsed input arguments.
        model_class (str): Model class, e.g. 'tts' or 'vocoder'. For 'tts' the
            default character set is stored with the copied config when the
            config has no 'characters' entry. The upper-cased value is used
            as the TensorBoard model name. The value is not validated here.

    Returns:
        c (TTS.utils.io.AttrDict): Config parameters.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.
        tb_logger (TTS.utils.tensorboard_logger.TensorboardLogger): Class that
            does the TensorBoard logging. None on every process whose
            `args.rank` is not 0.
    """
    if args.continue_path:
        # resume: everything is derived from the existing output folder
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model

    # setup output paths and read configs
    c = load_config(args.config_path)

    if 'mixed_precision' in c and c.mixed_precision:
        print("   >  Mixed precision mode is ON")

    out_path = args.continue_path
    if not out_path:
        out_path = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    audio_path = os.path.join(out_path, "test_audios")

    c_logger = ConsoleLogger()
    tb_logger = None

    # only the main process writes files and logs to TensorBoard
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if model_class == 'tts' and 'characters' not in c:
            new_fields['characters'] = parse_symbols()
        copy_model_files(c, args.config_path,
                         out_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(out_path, 0o775)

        tb_logger = TensorboardLogger(out_path,
                                      model_name=model_class.upper())

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", c["run_description"], 0)

    return c, out_path, audio_path, c_logger, tb_logger
|