diff --git a/TTS/trainer.py b/TTS/trainer.py index 8589ae5c..d75b8e14 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -4,16 +4,14 @@ import importlib import multiprocessing import os import platform -import re import sys import time import traceback from argparse import Namespace from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Union -from urllib.parse import urlparse +from inspect import signature +from typing import Callable, Dict, List, Tuple, Union -import fsspec import torch import torch.distributed as dist from coqpit import Coqpit @@ -21,11 +19,7 @@ from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader -from TTS.config import load_config, register_config -from TTS.tts.datasets import load_meta_data -from TTS.tts.models import setup_model as setup_tts_model -from TTS.tts.utils.text.symbols import parse_symbols -from TTS.utils.audio import AudioProcessor +from TTS.stt.datasets.tokenizer import Tokenizer from TTS.utils.callbacks import TrainerCallback from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import ( @@ -39,9 +33,13 @@ from TTS.utils.generic_utils import ( ) from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger -from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env -from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data -from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.utils.trainer_utils import ( + get_last_checkpoint, + get_optimizer, + get_scheduler, + is_apex_available, + setup_torch_training_env, +) multiprocessing.set_start_method("fork") @@ -80,6 +78,9 @@ class TrainingArgs(Coqpit): "help": "Best model file to be used for extracting the best loss. If not specified, the latest best model in continue path is used" }, ) + skip_train_epoch: bool = field( + default=False, metadata={"help": "Run only evaluation iteration. Useful for debugging."} + ) config_path: str = field(default="", metadata={"help": "Path to the configuration file."}) rank: int = field(default=0, metadata={"help": "Process rank in distributed training."}) group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) @@ -98,7 +99,14 @@ class Trainer: c_logger: ConsoleLogger = None, dashboard_logger: Union[TensorboardLogger, WandbLogger] = None, model: nn.Module = None, + get_model: Callable = None, + get_data_samples: Callable = None, + train_samples: List = None, + eval_samples: List = None, + tokenizer: Tokenizer = None, cudnn_benchmark: bool = False, + training_assets: Dict = {}, + parse_command_line_args: bool = True, ) -> None: """Simple yet powerful πŸΈπŸ’¬ TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models or easily be customized. @@ -127,24 +135,44 @@ class Trainer: model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` initializes a model from the provided config. Defaults to None. + get_model (Callable): + A function that returns a model. It is used to initialize the model when `model` is not provided. + It either takes the config as the only argument or does not take any argument. + Defaults to None + + get_data_samples (Callable): + A function that returns a list of training and evaluation samples. 
Used if `train_samples` and + `eval_samples` are None. Defaults to None. + + train_samples (List): + A list of training samples used by the model's `get_data_loader` to init the `dataset` and the + `data_loader`. Defaults to None. + + eval_samples (List): + A list of evaluation samples used by the model's `get_data_loader` to init the `dataset` and the + `data_loader`. Defaults to None. + cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input length is changing batch to batch along the training. + training_assets (Dict): + A dictionary of assets to be used at training and passed to the model's ```train_log(), eval_log(), get_data_loader()``` + during training. It can include `AudioProcessor` or/and `Tokenizer`. Defaults to {}. + + parse_command_line_args (bool): + If true, parse command-line arguments and update `TrainingArgs` and model `config` values. Set it + to false if you parse the arguments yourself. Defaults to True. + Examples: - Running trainer on a model. + Running trainer with HifiGAN model. >>> args = TrainingArgs(...) >>> config = HifiganConfig(...) >>> model = GANModel(config) - >>> trainer = Trainer(args, config, output_path, model=model) - >>> trainer.fit() - - Running trainer on a config. - - >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,) - >>> args, config, output_path, _, c_logger, dashboard_logger = init_training(TrainingArgs(), config) - >>> trainer = Trainer(args, config, output_path, c_logger, dashboard_logger) + >>> ap = AudioProcessor(**config.audio) + >>> assets = {"audio_processor": ap} + >>> trainer = Trainer(args, config, output_path, model=model, training_assets=assets) >>> trainer.fit() TODO: @@ -154,20 +182,33 @@ class Trainer: - Profiler integration. - Overfitting to a batch. - TPU training + - NOTE: Consider moving `training_assets` to the model implementation. 
""" + if parse_command_line_args: + # parse command-line arguments for TrainingArgs() + args, coqpit_overrides = self.parse_argv(args) - if config is None: - # parse config from console arguments - config, output_path, _, c_logger, dashboard_logger = process_args(args) + # get ready for training and parse command-line arguments for the model config + config = self.init_training(args, coqpit_overrides, config) + # define the experiment path and create the folder + output_path = get_experiment_folder_path(config.output_path, config.run_name) + os.makedirs(output_path, exist_ok=True) + + # copy training assets to the output folder + copy_model_files(config, output_path, new_fields=None) + + # init class members self.args = args self.config = config self.output_path = output_path self.config.output_log_path = output_path + self.training_assets = training_assets # setup logging log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") self._setup_logger_config(log_file) + time.sleep(1.0) # wait for the logger to be ready # set and initialize Pytorch runtime self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp) @@ -196,33 +237,30 @@ class Trainer: self.use_apex = self._is_apex_available() self.use_amp_scaler = self.config.mixed_precision and self.use_cuda - # init audio processor - self.ap = AudioProcessor(**self.config.audio.to_dict()) + # init tokenizer + self.tokenizer = tokenizer # load data samples - # TODO: refactor this - if "datasets" in self.config: - # load data for `tts` models - self.data_train, self.data_eval = load_meta_data(self.config.datasets) - elif self.config.feature_path is not None: - # load pre-comnputed features for `vocoder`models - print(f" > Loading features from: {self.config.feature_path}") - self.data_eval, self.data_train = load_wav_feat_data( - self.config.data_path, self.config.feature_path, self.config.eval_split_size - ) + if train_samples is None and get_data_samples is None: + raise ValueError("[!] `train_samples` and `get_data_samples` cannot both be None.") + if train_samples is not None: + self.train_samples = train_samples + self.eval_samples = eval_samples else: - # load data for `vocoder`models - self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size) + self.train_samples, self.eval_samples = self.run_get_data_samples(config, get_data_samples) # init TTS model + if model is None and get_model is None: + raise ValueError("[!] `model` and `get_model` cannot both be None.") if model is not None: self.model = model else: - self.model = self.get_model(self.config) + self.run_get_model(self.config, get_model) + # TODO: out! 
         # init multispeaker settings of the model
         if hasattr(self.model, "init_multispeaker"):
-            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+            self.model.init_multispeaker(self.config, self.train_samples + self.eval_samples)

         # setup criterion
         self.criterion = self.get_criterion(self.model)
@@ -247,7 +285,7 @@
         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)

-        # callback
+        # CALLBACK
         self.callbacks = TrainerCallback(self)
         self.callbacks.on_init_start()
@@ -280,7 +318,7 @@
         else:
             self.scheduler.last_epoch = self.restore_step

-        # DISTRUBUTED
+        # DISTRIBUTED
         if self.num_gpus > 1:
             self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
@@ -291,8 +329,54 @@
         self.callbacks.on_init_end()

     @staticmethod
-    def get_model(config: Coqpit) -> nn.Module:
-        """Initialize model from config.
+    def parse_argv(args: Union[Coqpit, List]):
+        """Parse command line arguments to init or override `TrainingArgs()`."""
+        if isinstance(args, Coqpit):
+            parser = args.init_argparse(arg_prefix="")
+        else:
+            train_config = TrainingArgs()
+            parser = train_config.init_argparse(arg_prefix="")
+        training_args, coqpit_overrides = parser.parse_known_args()
+        args.parse_args(training_args)
+        return args, coqpit_overrides
+
+    def init_training(self, args: TrainingArgs, coqpit_overrides: Dict, config: Coqpit = None):
+        """Initialize training and update model configs from command line arguments.
+
+        Args:
+            args (argparse.Namespace or dict like): Parsed input arguments.
+            coqpit_overrides (argparse.Namespace or dict like): Parsed config overriding arguments.
+            config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None.
+
+        Returns:
+            config (Coqpit): Config parameters.
+        """
+        # set arguments for continuing training
+        if args.continue_path:
+            experiment_path = args.continue_path
+            args.config_path = os.path.join(args.continue_path, "config.json")
+            args.restore_path, best_model = get_last_checkpoint(args.continue_path)
+            if not args.best_path:
+                args.best_path = best_model
+
+        # override config values from command-line args
+        # TODO: Maybe it is better to do it outside
+        if len(coqpit_overrides) > 0:
+            config.parse_known_args(coqpit_overrides, relaxed_parser=True)
+        experiment_path = args.continue_path
+
+        # update the config.json fields and copy it to the output folder
+        if args.rank == 0:
+            new_fields = {}
+            if args.restore_path:
+                new_fields["restore_path"] = args.restore_path
+            new_fields["github_branch"] = get_git_branch()
+            copy_model_files(config, experiment_path, new_fields)
+        return config
+
+    @staticmethod
+    def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
+        """Run the `get_model` function and return the model.

         Args:
             config (Coqpit): Model config.

         Returns:
             nn.Module: initialized model.
""" - try: - model = setup_vocoder_model(config) - except ModuleNotFoundError: - model = setup_tts_model(config) + if len(signature(get_model).sig.parameters) == 1: + model = get_model(config) + else: + model = get_model() return model + @staticmethod + def run_get_data_samples(config: Coqpit, get_data_samples: Callable) -> nn.Module: + if isinstance(get_data_samples, Callable): + if len(signature(get_data_samples).sig.parameters) == 1: + train_samples, eval_samples = get_data_samples(config) + else: + train_samples, eval_samples = get_data_samples() + return train_samples, eval_samples + else: + return None, None + def restore_model( self, config: Coqpit, @@ -366,11 +461,15 @@ class Trainer: torch.cuda.empty_cache() return model, optimizer, scaler, restore_step + ######################### + # DATA LOADING FUNCTIONS + ######################### + def _get_loader( self, model: nn.Module, config: Coqpit, - ap: AudioProcessor, + assets: Dict, is_eval: bool, data_items: List, verbose: bool, @@ -379,14 +478,14 @@ class Trainer: if num_gpus > 1: if hasattr(model.module, "get_data_loader"): loader = model.module.get_data_loader( - config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank + config, assets, is_eval, data_items, verbose, num_gpus, self.args.rank ) else: if hasattr(model, "get_data_loader"): - loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus) + loader = model.get_data_loader(config, assets, is_eval, data_items, verbose, num_gpus) return loader - def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: + def get_train_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: """Initialize and return a training data loader. Args: @@ -397,10 +496,10 @@ class Trainer: Returns: DataLoader: Initialized training data loader. """ - return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus) + return self._get_loader(self.model, self.config, training_assets, False, data_items, verbose, self.num_gpus) - def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: - return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus) + def get_eval_dataloader(self, training_assets: Dict, data_items: List, verbose: bool) -> DataLoader: + return self._get_loader(self.model, self.config, training_assets, True, data_items, verbose, self.num_gpus) def format_batch(self, batch: List) -> Dict: """Format the dataloader output and return a batch. @@ -420,6 +519,10 @@ class Trainer: batch[k] = to_cuda(v) return batch + ###################### + # TRAIN FUNCTIONS + ###################### + @staticmethod def master_params(optimizer: torch.optim.Optimizer): """Generator over parameters owned by the optimizer. @@ -567,24 +670,6 @@ class Trainer: loss_dict["grad_norm"] = grad_norm return outputs, loss_dict, step_time - @staticmethod - def _detach_loss_dict(loss_dict: Dict) -> Dict: - """Detach loss values from autograp. - - Args: - loss_dict (Dict): losses. - - Returns: - Dict: losses detached from autograph. 
- """ - loss_dict_detached = {} - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_detached[key] = value - else: - loss_dict_detached[key] = value.item() - return loss_dict_detached - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: """Perform a training step on a batch of inputs and log the process. @@ -700,15 +785,14 @@ class Trainer: self.dashboard_logger.log_artifact(self.output_path, "checkpoint", "model", aliases) # training visualizations - figures, audios = None, None if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"): - figures, audios = self.model.module.train_log(self.ap, batch, outputs) + self.model.module.train_log( + batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done + ) elif hasattr(self.model, "train_log"): - figures, audios = self.model.train_log(self.ap, batch, outputs) - if figures is not None: - self.dashboard_logger.train_figures(self.total_steps_done, figures) - if audios is not None: - self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.model.train_log( + batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done + ) self.dashboard_logger.flush() @@ -718,11 +802,13 @@ class Trainer: def train_epoch(self) -> None: """Main entry point for the training loop. Run training on the all training samples.""" + # initialize the data loader self.train_loader = self.get_train_dataloader( - self.ap, - self.data_train, + self.training_assets, + self.train_samples, verbose=True, ) + # set model to training mode if self.num_gpus > 1: self.model.module.train() else: @@ -734,11 +820,12 @@ class Trainer: batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) self.c_logger.print_train_start() loader_start_time = time.time() + # iterate over the training samples for cur_step, batch in enumerate(self.train_loader): _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) loader_start_time = time.time() epoch_time = time.time() - epoch_start_time - # Plot self.epochs_done Stats + # plot self.epochs_done Stats if self.args.rank == 0: epoch_stats = {"epoch_time": epoch_time} epoch_stats.update(self.keep_avg_train.avg_values) @@ -754,6 +841,10 @@ class Trainer: else: self.scheduler.step() + ####################### + # EVAL FUNCTIONS + ####################### + @staticmethod def _model_eval_step( batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None @@ -803,7 +894,7 @@ class Trainer: loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss") loss_dict.update(loss_dict_new) - loss_dict = self._detach_loss_dict(loss_dict) + loss_dict = self._detach_loss_dict(loss_dict) # update avg stats update_eval_values = {} @@ -819,8 +910,8 @@ class Trainer: """Main entry point for the evaluation loop. 
Run evaluation on the all validation samples.""" self.eval_loader = ( self.get_eval_dataloader( - self.ap, - self.data_eval, + self.training_assets, + self.eval_samples, verbose=True, ) if self.config.run_eval @@ -840,15 +931,12 @@ class Trainer: loader_start_time = time.time() # plot epoch stats, artifacts and figures if self.args.rank == 0: - figures, audios = None, None if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"): - figures, audios = self.model.module.eval_log(self.ap, batch, outputs) + self.model.module.eval_log( + batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done + ) elif hasattr(self.model, "eval_log"): - figures, audios = self.model.eval_log(self.ap, batch, outputs) - if figures is not None: - self.dashboard_logger.eval_figures(self.total_steps_done, figures) - if audios is not None: - self.dashboard_logger.eval_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.model.eval_log(batch, outputs, self.dashboard_logger, self.training_assets, self.total_steps_done) self.dashboard_logger.eval_stats(self.total_steps_done, self.keep_avg_eval.avg_values) def test_run(self) -> None: @@ -857,22 +945,22 @@ class Trainer: if hasattr(self.model, "test_run") or (self.num_gpus > 1 and hasattr(self.model.module, "test_run")): if self.eval_loader is None: self.eval_loader = self.get_eval_dataloader( - self.ap, - self.data_eval, + self.training_assets, + self.eval_samples, verbose=True, ) if hasattr(self.eval_loader.dataset, "load_test_samples"): samples = self.eval_loader.dataset.load_test_samples(1) if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.ap, samples, None) + figures, audios = self.model.module.test_run(self.training_assets, samples, None) else: - figures, audios = self.model.test_run(self.ap, samples, None) + figures, audios = self.model.test_run(self.training_assets, samples, None) else: if self.num_gpus > 1: - figures, audios = self.model.module.test_run(self.ap) + figures, audios = self.model.module.test_run(self.training_assets) else: - figures, audios = self.model.test_run(self.ap) + figures, audios = self.model.test_run(self.training_assets) self.dashboard_logger.test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) self.dashboard_logger.test_figures(self.total_steps_done, figures) @@ -886,6 +974,10 @@ class Trainer: self.best_loss = ch["model_loss"] print(f" > Starting with loaded last best loss {self.best_loss}.") + ################################### + # FIT FUNCTIONS + ################################### + def _fit(self) -> None: """πŸƒ train -> evaluate -> test for the number of epochs.""" self._restore_best_loss() @@ -901,7 +993,8 @@ class Trainer: self.keep_avg_eval = KeepAverage() if self.config.run_eval else None self.epochs_done = epoch self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path) - self.train_epoch() + if not self.args.skip_train_epoch: + self.train_epoch() if self.config.run_eval: self.eval_epoch() if epoch >= self.config.test_delay_epochs and self.args.rank <= 0: @@ -939,24 +1032,6 @@ class Trainer: traceback.print_exc() sys.exit(1) - def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: - """Pick the target loss to compare models""" - target_avg_loss = None - - # return if target loss defined in the model config - if "target_loss" in self.config and self.config.target_loss: - return keep_avg_target[f"avg_{self.config.target_loss}"] - - # take the average of loss_{optimizer_idx} as the target loss 
when there are multiple optimizers - if isinstance(self.optimizer, list): - target_avg_loss = 0 - for idx in range(len(self.optimizer)): - target_avg_loss += keep_avg_target[f"avg_loss_{idx}"] - target_avg_loss /= len(self.optimizer) - else: - target_avg_loss = keep_avg_target["avg_loss"] - return target_avg_loss - def save_best_model(self) -> None: """Save the best model. It only saves if the current target loss is smaller then the previous.""" @@ -978,35 +1053,9 @@ class Trainer: keep_after=self.config.keep_after, ) - def _setup_logger_config(self, log_file: str) -> None: - """Write log strings to a file and print logs to the terminal. - TODO: Causes formatting issues in pdb debugging.""" - - class Logger(object): - def __init__(self, print_to_terminal=True): - self.print_to_terminal = print_to_terminal - self.terminal = sys.stdout - self.log_file = log_file - - def write(self, message): - if self.print_to_terminal: - self.terminal.write(message) - with open(self.log_file, "a", encoding="utf-8") as f: - f.write(message) - - def flush(self): - # this flush method is needed for python 3 compatibility. - # this handles the flush command by doing nothing. - # you might want to specify some extra behavior here. - pass - - # don't let processes rank > 0 write to the terminal - sys.stdout = Logger(self.args.rank == 0) - - @staticmethod - def _is_apex_available() -> bool: - """Check if Nvidia's APEX is available.""" - return importlib.util.find_spec("apex") is not None + ##################### + # GET FUNCTIONS + ##################### @staticmethod def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]: @@ -1084,154 +1133,72 @@ class Trainer: criterion = model.get_criterion() return criterion + #################### + # HELPER FUNCTIONS + #################### -def getarguments(): - train_config = TrainingArgs() - parser = train_config.init_argparse(arg_prefix="") - return parser + @staticmethod + def _detach_loss_dict(loss_dict: Dict) -> Dict: + """Detach loss values from autograp. + Args: + loss_dict (Dict): losses. -def get_last_checkpoint(path: str) -> Tuple[str, str]: - """Get latest checkpoint or/and best model in path. + Returns: + Dict: losses detached from autograph. + """ + loss_dict_detached = {} + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_detached[key] = value + else: + loss_dict_detached[key] = value.detach() + return loss_dict_detached - It is based on globbing for `*.pth.tar` and the RegEx - `(checkpoint|best_model)_([0-9]+)`. + def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict: + """Pick the target loss to compare models""" + target_avg_loss = None - Args: - path: Path to files to be compared. + # return if target loss defined in the model config + if "target_loss" in self.config and self.config.target_loss: + return keep_avg_target[f"avg_{self.config.target_loss}"] - Raises: - ValueError: If no checkpoint or best_model files are found. - - Returns: - Path to the last checkpoint - Path to best checkpoint - """ - fs = fsspec.get_mapper(path).fs - file_names = fs.glob(os.path.join(path, "*.pth.tar")) - scheme = urlparse(path).scheme - if scheme: # scheme is not preserved in fs.glob, add it back - file_names = [scheme + "://" + file_name for file_name in file_names] - last_models = {} - last_model_nums = {} - for key in ["checkpoint", "best_model"]: - last_model_num = None - last_model = None - # pass all the checkpoint files and find - # the one with the largest model number suffix. 
- for file_name in file_names: - match = re.search(f"{key}_([0-9]+)", file_name) - if match is not None: - model_num = int(match.groups()[0]) - if last_model_num is None or model_num > last_model_num: - last_model_num = model_num - last_model = file_name - - # if there is no checkpoint found above - # find the checkpoint with the latest - # modification date. - key_file_names = [fn for fn in file_names if key in fn] - if last_model is None and len(key_file_names) > 0: - last_model = max(key_file_names, key=os.path.getctime) - last_model_num = load_fsspec(last_model)["step"] - - if last_model is not None: - last_models[key] = last_model - last_model_nums[key] = last_model_num - - # check what models were found - if not last_models: - raise ValueError(f"No models found in continue path {path}!") - if "checkpoint" not in last_models: # no checkpoint just best model - last_models["checkpoint"] = last_models["best_model"] - elif "best_model" not in last_models: # no best model - # this shouldn't happen, but let's handle it just in case - last_models["best_model"] = last_models["checkpoint"] - # finally check if last best model is more recent than checkpoint - elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: - last_models["checkpoint"] = last_models["best_model"] - - return last_models["checkpoint"], last_models["best_model"] - - -def process_args(args, config=None): - """Process parsed comand line arguments and initialize the config if not provided. - - Args: - args (argparse.Namespace or dict like): Parsed input arguments. - config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. - - Returns: - c (TTS.utils.io.AttrDict): Config paramaters. - out_path (str): Path to save models and logging. - audio_path (str): Path to save generated test audios. - c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does - logging to the console. - - dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging - - TODO: - - Interactive config definition. 
- """ - if isinstance(args, tuple): - args, coqpit_overrides = args - if args.continue_path: - # continue a previous training from its output folder - experiment_path = args.continue_path - args.config_path = os.path.join(args.continue_path, "config.json") - args.restore_path, best_model = get_last_checkpoint(args.continue_path) - if not args.best_path: - args.best_path = best_model - # init config if not already defined - if config is None: - if args.config_path: - # init from a file - config = load_config(args.config_path) + # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers + if isinstance(self.optimizer, list): + target_avg_loss = 0 + for idx in range(len(self.optimizer)): + target_avg_loss += keep_avg_target[f"avg_loss_{idx}"] + target_avg_loss /= len(self.optimizer) else: - # init from console args - from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + target_avg_loss = keep_avg_target["avg_loss"] + return target_avg_loss - config_base = BaseTrainingConfig() - config_base.parse_known_args(coqpit_overrides) - config = register_config(config_base.model)() - # override values from command-line args - config.parse_known_args(coqpit_overrides, relaxed_parser=True) - experiment_path = args.continue_path - if not experiment_path: - experiment_path = get_experiment_folder_path(config.output_path, config.run_name) - audio_path = os.path.join(experiment_path, "test_audios") - config.output_log_path = experiment_path - # setup rank 0 process in distributed training - dashboard_logger = None - if args.rank == 0: - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - # if model characters are not set in the config file - # save the default set to the config file for future - # compatibility. - if config.has("characters") and config.characters is None: - used_characters = parse_symbols() - new_fields["characters"] = used_characters - copy_model_files(config, experiment_path, new_fields) - dashboard_logger = init_dashboard_logger(config) - c_logger = ConsoleLogger() - return config, experiment_path, audio_path, c_logger, dashboard_logger + def _setup_logger_config(self, log_file: str) -> None: + """Write log strings to a file and print logs to the terminal. + TODO: Causes formatting issues in pdb debugging.""" + class Logger(object): + def __init__(self, print_to_terminal=True): + self.print_to_terminal = print_to_terminal + self.terminal = sys.stdout + self.log_file = log_file -def init_arguments(): - train_config = TrainingArgs() - parser = train_config.init_argparse(arg_prefix="") - return parser + def write(self, message): + if self.print_to_terminal: + self.terminal.write(message) + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(message) + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. 
+ pass -def init_training(argv: Union[List, Coqpit], config: Coqpit = None): - """Initialization of a training run.""" - if isinstance(argv, Coqpit): - parser = argv.init_argparse(arg_prefix="") - else: - parser = init_arguments() - args = parser.parse_known_args() - config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) - return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger + # don't let processes rank > 0 write to the terminal + sys.stdout = Logger(self.args.rank == 0) + + @staticmethod + def _is_apex_available() -> bool: + """Check if Nvidia's APEX is available.""" + return importlib.util.find_spec("apex") is not None diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/dataset.py similarity index 100% rename from TTS/tts/datasets/TTSDataset.py rename to TTS/tts/datasets/dataset.py diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 06c7cb2b..0c9f60e8 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from TTS.model import BaseModel -from TTS.tts.datasets import TTSDataset +from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text import make_symbols @@ -32,6 +32,30 @@ class BaseTTS(BaseModel): - 1D tensors `batch x 1` """ + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + if "characters" in config: + _, self.config, num_chars = self.get_characters(config) + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + @staticmethod def get_characters(config: Coqpit) -> str: # TODO: implement CharacterProcessor @@ -169,7 +193,7 @@ class BaseTTS(BaseModel): def get_data_loader( self, config: Coqpit, - ap: AudioProcessor, + assets: Dict, is_eval: bool, data_items: List, verbose: bool, @@ -179,6 +203,8 @@ class BaseTTS(BaseModel): if is_eval and not config.run_eval: loader = None else: + ap = assets["audio_processor"] + # setup multi-speaker attributes if hasattr(self, "speaker_manager"): speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None @@ -280,14 +306,18 @@ class BaseTTS(BaseModel): ) return loader - def test_run(self, ap) -> Tuple[Dict, Dict]: + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: """Generic test run for `tts` models used by `Trainer`. You can override this for a different behaviour. + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + Returns: Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
""" + ap = assets["audio_processor"] print(" | > Synthesizing test sentences.") test_audios = {} test_figures = {} diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py index 005114d1..dabb33cd 100644 --- a/TTS/utils/trainer_utils.py +++ b/TTS/utils/trainer_utils.py @@ -1,8 +1,13 @@ import importlib +import os +import re from typing import Dict, List, Tuple +from urllib.parse import urlparse +import fsspec import torch +from TTS.utils.io import load_fsspec from TTS.utils.training import NoamLR @@ -80,3 +85,66 @@ def get_optimizer( if model is not None: parameters = model.parameters() return optimizer(parameters, lr=lr, **optimizer_params) + + +def get_last_checkpoint(path: str) -> Tuple[str, str]: + """Get latest checkpoint or/and best model in path. + + It is based on globbing for `*.pth.tar` and the RegEx + `(checkpoint|best_model)_([0-9]+)`. + + Args: + path: Path to files to be compared. + + Raises: + ValueError: If no checkpoint or best_model files are found. + + Returns: + Path to the last checkpoint + Path to best checkpoint + """ + fs = fsspec.get_mapper(path).fs + file_names = fs.glob(os.path.join(path, "*.pth.tar")) + scheme = urlparse(path).scheme + if scheme: # scheme is not preserved in fs.glob, add it back + file_names = [scheme + "://" + file_name for file_name in file_names] + last_models = {} + last_model_nums = {} + for key in ["checkpoint", "best_model"]: + last_model_num = None + last_model = None + # pass all the checkpoint files and find + # the one with the largest model number suffix. + for file_name in file_names: + match = re.search(f"{key}_([0-9]+)", file_name) + if match is not None: + model_num = int(match.groups()[0]) + if last_model_num is None or model_num > last_model_num: + last_model_num = model_num + last_model = file_name + + # if there is no checkpoint found above + # find the checkpoint with the latest + # modification date. + key_file_names = [fn for fn in file_names if key in fn] + if last_model is None and len(key_file_names) > 0: + last_model = max(key_file_names, key=os.path.getctime) + last_model_num = load_fsspec(last_model)["step"] + + if last_model is not None: + last_models[key] = last_model + last_model_nums[key] = last_model_num + + # check what models were found + if not last_models: + raise ValueError(f"No models found in continue path {path}!") + if "checkpoint" not in last_models: # no checkpoint just best model + last_models["checkpoint"] = last_models["best_model"] + elif "best_model" not in last_models: # no best model + # this shouldn't happen, but let's handle it just in case + last_models["best_model"] = last_models["checkpoint"] + # finally check if last best model is more recent than checkpoint + elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: + last_models["checkpoint"] = last_models["best_model"] + + return last_models["checkpoint"], last_models["best_model"] diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py index f879cd42..9d6ef26f 100644 --- a/TTS/vocoder/models/base_vocoder.py +++ b/TTS/vocoder/models/base_vocoder.py @@ -1,3 +1,5 @@ +from coqpit import Coqpit + from TTS.model import BaseModel # pylint: skip-file @@ -16,5 +18,35 @@ class BaseVocoder(BaseModel): - 1D tensors `batch x 1` """ - def __init__(self): - super().__init__() + def __init__(self, config): + super().__init__(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type. 
+
+        If the config is for training with a name like "*Config", then the model args are embedded in
+        config.model_args
+
+        If the config is for the model with a name like "*Args", then we assign it directly.
+        """
+        # don't use isinstance to avoid a recursive import
+        if "Config" in config.__class__.__name__:
+            if "characters" in config:
+                _, self.config, num_chars = self.get_characters(config)
+                self.config.num_chars = num_chars
+                if hasattr(self.config, "model_args"):
+                    config.model_args.num_chars = num_chars
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+            else:
+                self.config = config
+                if "model_args" in config:
+                    self.args = self.config.model_args
+                # This is for backward compatibility
+                if "model_params" in config:
+                    self.args = self.config.model_params
+        else:
+            raise ValueError("config must be either a *Config or *Args")
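
Usage sketch (illustrative, not part of the patch): how the refactored `Trainer` above is meant to be driven, following the new constructor signature and its docstring example. `TrainingArgs`, `HifiganConfig`, `AudioProcessor`, `load_wav_data`, and `GANModel` are names used in this diff or its docstrings; the dataset path, the config fields, and the `GANModel` import (not shown in this diff) are assumptions.

    from TTS.trainer import Trainer, TrainingArgs
    from TTS.utils.audio import AudioProcessor
    from TTS.vocoder.configs import HifiganConfig
    from TTS.vocoder.datasets.preprocess import load_wav_data

    args = TrainingArgs()
    config = HifiganConfig(data_path="/data/LJSpeech-1.1/wavs/", output_path="output/")

    # training assets are forwarded to the model's train_log()/eval_log()/get_data_loader()
    ap = AudioProcessor(**config.audio)
    assets = {"audio_processor": ap}

    # Option 1: pass the samples directly.
    # load_wav_data() returns (eval_samples, train_samples), as in the code removed above.
    eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
    model = GANModel(config)  # model class as named in the docstring example; import omitted here
    trainer = Trainer(
        args,
        config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        training_assets=assets,
    )
    trainer.fit()

    # Option 2: let the trainer call factories. run_get_model()/run_get_data_samples()
    # inspect the callable's signature and pass `config` only if it takes one argument.
    def get_data_samples(cfg):
        eval_s, train_s = load_wav_data(cfg.data_path, cfg.eval_split_size)
        return train_s, eval_s

    trainer = Trainer(
        args,
        config,
        config.output_path,
        get_model=lambda cfg: GANModel(cfg),
        get_data_samples=get_data_samples,
        training_assets=assets,
    )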