diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 863bd3b9..cfd092f1 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,12 +1,62 @@ -import sys +import os -from TTS.trainer import Trainer, init_training +from TTS.config import load_config, register_config +from TTS.trainer import Trainer, TrainingArgs +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models import setup_model +from TTS.utils.audio import AudioProcessor def main(): - """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```""" - args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv) - trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False) + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainingArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) trainer.fit() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 000083e0..cd665f29 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,26 +1,69 @@ import os -import sys -import traceback -from TTS.trainer import Trainer, init_training -from TTS.utils.generic_utils import remove_experiment_folder +from TTS.config import load_config, register_config +from TTS.trainer import Trainer, TrainingArgs +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model def main(): - try: - args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv) - trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) - trainer.fit() - except KeyboardInterrupt: - remove_experiment_folder(output_path) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(output_path) - traceback.print_exc() - sys.exit(1) + """Run `tts` model training directly by a `config.json` file.""" + # init trainer args + train_args = TrainingArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + if "feature_path" in config and config.feature_path: + # load pre-computed features + print(f" > Loading features from: {config.feature_path}") + eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size) + else: + # load data raw wav files + eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) + trainer.fit() if __name__ == "__main__":