mirror of https://github.com/coqui-ai/TTS.git
Update `train_tts.py` and `train_vocoder.py`
This commit is contained in:
parent
2e9b6b4f90
commit
ba2b8c827f
|
@ -1,12 +1,62 @@
|
|||
import sys
|
||||
import os
|
||||
|
||||
from TTS.trainer import Trainer, init_training
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
def main():
|
||||
"""Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
|
||||
args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
|
||||
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger, cudnn_benchmark=False)
|
||||
"""Run `tts` model training directly by a `config.json` file."""
|
||||
# init trainer args
|
||||
train_args = TrainingArgs()
|
||||
parser = train_args.init_argparse(arg_prefix="")
|
||||
|
||||
# override trainer args from comman-line args
|
||||
args, config_overrides = parser.parse_known_args()
|
||||
train_args.parse_args(args)
|
||||
|
||||
# load config.json and register
|
||||
if args.config_path or args.continue_path:
|
||||
if args.config_path:
|
||||
# init from a file
|
||||
config = load_config(args.config_path)
|
||||
if len(config_overrides) > 0:
|
||||
config.parse_known_args(config_overrides, relaxed_parser=True)
|
||||
elif args.continue_path:
|
||||
# continue from a prev experiment
|
||||
config = load_config(os.path.join(args.continue_path, "config.json"))
|
||||
if len(config_overrides) > 0:
|
||||
config.parse_known_args(config_overrides, relaxed_parser=True)
|
||||
else:
|
||||
# init from console args
|
||||
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
|
||||
|
||||
config_base = BaseTrainingConfig()
|
||||
config_base.parse_known_args(config_overrides)
|
||||
config = register_config(config_base.model)()
|
||||
|
||||
# load training samples
|
||||
train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True)
|
||||
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
||||
# init the model from config
|
||||
model = setup_model(config)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
train_args,
|
||||
config,
|
||||
config.output_path,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
training_assets={"audio_processor": ap},
|
||||
parse_command_line_args=False,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
|
||||
|
|
|
@ -1,26 +1,69 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from TTS.trainer import Trainer, init_training
|
||||
from TTS.utils.generic_utils import remove_experiment_folder
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.trainer import Trainer, TrainingArgs
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.models import setup_model
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
args, config, output_path, _, c_logger, dashboard_logger = init_training(sys.argv)
|
||||
trainer = Trainer(args, config, output_path, c_logger, dashboard_logger)
|
||||
trainer.fit()
|
||||
except KeyboardInterrupt:
|
||||
remove_experiment_folder(output_path)
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
except Exception: # pylint: disable=broad-except
|
||||
remove_experiment_folder(output_path)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
"""Run `tts` model training directly by a `config.json` file."""
|
||||
# init trainer args
|
||||
train_args = TrainingArgs()
|
||||
parser = train_args.init_argparse(arg_prefix="")
|
||||
|
||||
# override trainer args from comman-line args
|
||||
args, config_overrides = parser.parse_known_args()
|
||||
train_args.parse_args(args)
|
||||
|
||||
# load config.json and register
|
||||
if args.config_path or args.continue_path:
|
||||
if args.config_path:
|
||||
# init from a file
|
||||
config = load_config(args.config_path)
|
||||
if len(config_overrides) > 0:
|
||||
config.parse_known_args(config_overrides, relaxed_parser=True)
|
||||
elif args.continue_path:
|
||||
# continue from a prev experiment
|
||||
config = load_config(os.path.join(args.continue_path, "config.json"))
|
||||
if len(config_overrides) > 0:
|
||||
config.parse_known_args(config_overrides, relaxed_parser=True)
|
||||
else:
|
||||
# init from console args
|
||||
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
|
||||
|
||||
config_base = BaseTrainingConfig()
|
||||
config_base.parse_known_args(config_overrides)
|
||||
config = register_config(config_base.model)()
|
||||
|
||||
# load training samples
|
||||
if "feature_path" in config and config.feature_path:
|
||||
# load pre-computed features
|
||||
print(f" > Loading features from: {config.feature_path}")
|
||||
eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size)
|
||||
else:
|
||||
# load data raw wav files
|
||||
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
|
||||
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
||||
# init the model from config
|
||||
model = setup_model(config)
|
||||
|
||||
# init the trainer and 🚀
|
||||
trainer = Trainer(
|
||||
train_args,
|
||||
config,
|
||||
config.output_path,
|
||||
model=model,
|
||||
train_samples=train_samples,
|
||||
eval_samples=eval_samples,
|
||||
training_assets={"audio_processor": ap},
|
||||
parse_command_line_args=False,
|
||||
)
|
||||
trainer.fit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue