mirror of https://github.com/coqui-ai/TTS.git
Refactor Speaker Encoder training
This commit is contained in:
parent
043dca61b4
commit
2e9b6b4f90
|
@ -12,9 +12,9 @@ from torch.utils.data import DataLoader
|
|||
from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
|
||||
from TTS.speaker_encoder.utils.training import init_training
|
||||
from TTS.speaker_encoder.utils.visual import plot_embeddings
|
||||
from TTS.trainer import init_training
|
||||
from TTS.tts.datasets import load_meta_data
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.io import load_fsspec
|
||||
|
@ -156,7 +156,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
optimizer = RAdam(model.parameters(), lr=c.lr)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=False)
|
||||
meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)
|
||||
|
||||
data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
|
||||
|
||||
|
@ -208,7 +208,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training(sys.argv)
|
||||
args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training()
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
from typing import List, Union
|
||||
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.trainer import TrainingArgs
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.logging import init_dashboard_logger
|
||||
from TTS.utils.logging.console_logger import ConsoleLogger
|
||||
from TTS.utils.trainer_utils import get_last_checkpoint
|
||||
|
||||
|
||||
def getarguments():
|
||||
train_config = TrainingArgs()
|
||||
parser = train_config.init_argparse(arg_prefix="")
|
||||
return parser
|
||||
|
||||
|
||||
def process_args(args, config=None):
|
||||
"""Process parsed comand line arguments and initialize the config if not provided.
|
||||
Args:
|
||||
args (argparse.Namespace or dict like): Parsed input arguments.
|
||||
config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
|
||||
Returns:
|
||||
c (TTS.utils.io.AttrDict): Config paramaters.
|
||||
out_path (str): Path to save models and logging.
|
||||
audio_path (str): Path to save generated test audios.
|
||||
c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
|
||||
logging to the console.
|
||||
dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging
|
||||
TODO:
|
||||
- Interactive config definition.
|
||||
"""
|
||||
if isinstance(args, tuple):
|
||||
args, coqpit_overrides = args
|
||||
if args.continue_path:
|
||||
# continue a previous training from its output folder
|
||||
experiment_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, "config.json")
|
||||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||
if not args.best_path:
|
||||
args.best_path = best_model
|
||||
# init config if not already defined
|
||||
if config is None:
|
||||
if args.config_path:
|
||||
# init from a file
|
||||
config = load_config(args.config_path)
|
||||
else:
|
||||
# init from console args
|
||||
from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel
|
||||
|
||||
config_base = BaseTrainingConfig()
|
||||
config_base.parse_known_args(coqpit_overrides)
|
||||
config = register_config(config_base.model)()
|
||||
# override values from command-line args
|
||||
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
|
||||
experiment_path = args.continue_path
|
||||
if not experiment_path:
|
||||
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
|
||||
audio_path = os.path.join(experiment_path, "test_audios")
|
||||
config.output_log_path = experiment_path
|
||||
# setup rank 0 process in distributed training
|
||||
dashboard_logger = None
|
||||
if args.rank == 0:
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
# if model characters are not set in the config file
|
||||
# save the default set to the config file for future
|
||||
# compatibility.
|
||||
if config.has("characters") and config.characters is None:
|
||||
used_characters = parse_symbols()
|
||||
new_fields["characters"] = used_characters
|
||||
copy_model_files(config, experiment_path, new_fields)
|
||||
dashboard_logger = init_dashboard_logger(config)
|
||||
c_logger = ConsoleLogger()
|
||||
return config, experiment_path, audio_path, c_logger, dashboard_logger
|
||||
|
||||
|
||||
def init_arguments():
|
||||
train_config = TrainingArgs()
|
||||
parser = train_config.init_argparse(arg_prefix="")
|
||||
return parser
|
||||
|
||||
|
||||
def init_training(config: Coqpit = None):
|
||||
"""Initialization of a training run."""
|
||||
parser = init_arguments()
|
||||
args = parser.parse_known_args()
|
||||
config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config)
|
||||
return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger
|
Loading…
Reference in New Issue