From bf562cf437b5036c37188f023bb5bbdcda7fed93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Sat, 7 Aug 2021 21:30:07 +0000
Subject: [PATCH] Update `trainer.py`

Fix multi-speaker initialization of models.
Add changes for end2end `tts` models.
---
 TTS/trainer.py | 147 ++++++++++++++++++++++++++++++++++---------------
 1 file changed, 104 insertions(+), 43 deletions(-)

diff --git a/TTS/trainer.py b/TTS/trainer.py
index 3ac83601..a3e87e67 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -2,6 +2,7 @@
 import importlib
 import logging
+import multiprocessing
 import os
 import platform
 import re
@@ -42,6 +43,8 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 
+multiprocessing.set_start_method("fork")
+
 if platform.system() != "Windows":
     # https://github.com/pytorch/pytorch/issues/973
     import resource
@@ -149,7 +152,6 @@ class Trainer:
 
         # set and initialize Pytorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
-
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, tb_logger = process_args(args)
@@ -184,7 +186,7 @@ class Trainer:
         # init audio processor
         self.ap = AudioProcessor(**self.config.audio.to_dict())
 
-        # load dataset samples
+        # load data samples
         # TODO: refactor this
         if "datasets" in self.config:
             # load data for `tts` models
@@ -205,6 +207,10 @@ class Trainer:
         else:
             self.model = self.get_model(self.config)
 
+        # init multispeaker settings of the model
+        if hasattr(self.model, "init_multispeaker"):
+            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+
         # setup criterion
         self.criterion = self.get_criterion(self.model)
@@ -274,9 +280,9 @@ class Trainer:
         """
         # TODO: better model setup
         try:
-            model = setup_tts_model(config)
-        except ModuleNotFoundError:
             model = setup_vocoder_model(config)
+        except ModuleNotFoundError:
+            model = setup_tts_model(config)
         return model
 
     def restore_model(
@@ -417,7 +423,7 @@ class Trainer:
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List],  # pylint: disable=protected-access
         config: Coqpit,
         optimizer_idx: int = None,
-    ) -> Tuple[Dict, Dict, int, torch.Tensor]:
+    ) -> Tuple[Dict, Dict, int]:
         """Perform a forward - backward pass and run the optimizer.
 
         Args:
@@ -426,7 +432,7 @@ class Trainer:
             optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list, `optimizer_idx` must be defined to indicate the optimizer in use.
             scaler (AMPScaler): AMP scaler.
             criterion (nn.Module): Model's criterion.
-            scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
             config (Coqpit): Model config.
             optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
 
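
As a side note on the `init_multispeaker` hunk above (@@ -205,6 +207,10), the trainer only duck-types the call: any model that exposes the method gets the joined train + eval samples. A minimal, self-contained sketch of that pattern; `FakeSpeakerModel` and its `speaker_name` samples are hypothetical stand-ins, not TTS classes:

    # sketch: only models exposing `init_multispeaker` get multi-speaker setup
    class FakeSpeakerModel:
        def init_multispeaker(self, config, data):
            # derive the speaker set from the concatenated train + eval samples
            self.speakers = sorted({item["speaker_name"] for item in data})

    model = FakeSpeakerModel()
    data_train = [{"speaker_name": "spk_a"}]
    data_eval = [{"speaker_name": "spk_b"}]
    if hasattr(model, "init_multispeaker"):
        model.init_multispeaker({}, data_train + data_eval)
    print(model.speakers)  # ['spk_a', 'spk_b']
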
""" + step_start_time = time.time() # zero-out optimizer optimizer.zero_grad() @@ -448,11 +455,11 @@ class Trainer: # skip the rest if outputs is None: step_time = time.time() - step_start_time - return None, {}, step_time, 0 + return None, {}, step_time # check nan loss if torch.isnan(loss_dict["loss"]).any(): - raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.") + raise RuntimeError(f" > Detected NaN loss - {loss_dict}.") # set gradient clipping threshold if "grad_clip" in config and config.grad_clip is not None: @@ -463,7 +470,6 @@ class Trainer: else: grad_clip = 0.0 # meaning no gradient clipping - # TODO: compute grad norm if grad_clip <= 0: grad_norm = 0 @@ -474,15 +480,17 @@ class Trainer: with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss: scaled_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( - amp.master_params(optimizer), - grad_clip, + amp.master_params(optimizer), grad_clip, error_if_nonfinite=False ) else: # model optimizer step in mixed precision mode scaler.scale(loss_dict["loss"]).backward() - scaler.unscale_(optimizer) if grad_clip > 0: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False) + # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN + if torch.isnan(grad_norm) or torch.isinf(grad_norm): + grad_norm = 0 scale_prev = scaler.get_scale() scaler.step(optimizer) scaler.update() @@ -491,13 +499,13 @@ class Trainer: # main model optimizer step loss_dict["loss"].backward() if grad_clip > 0: - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False) optimizer.step() step_time = time.time() - step_start_time # setup lr - if scheduler is not None and update_lr_scheduler: + if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch: scheduler.step() # detach losses @@ -505,7 +513,9 @@ class Trainer: if optimizer_idx is not None: loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss") loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm - return outputs, loss_dict, step_time, grad_norm + else: + loss_dict["grad_norm"] = grad_norm + return outputs, loss_dict, step_time @staticmethod def _detach_loss_dict(loss_dict: Dict) -> Dict: @@ -544,11 +554,10 @@ class Trainer: # conteainers to hold model outputs and losses for each optimizer. 
@@ -544,11 +554,10 @@ class Trainer:
         # containers to hold model outputs and losses for each optimizer.
         outputs_per_optimizer = None
-        log_dict = {}
         loss_dict = {}
         if not isinstance(self.optimizer, list):
             # training with a single optimizer
-            outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+            outputs, loss_dict_new, step_time = self._optimize(
                 batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
             )
             loss_dict.update(loss_dict_new)
@@ -560,25 +569,36 @@ class Trainer:
                 criterion = self.criterion
                 scaler = self.scaler[idx] if self.use_amp_scaler else None
                 scheduler = self.scheduler[idx]
-                outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+                outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                 )
                 # skip the rest if the model returns None
                 total_step_time += step_time
                 outputs_per_optimizer[idx] = outputs
+                # merge loss_dicts from each optimizer
+                # rename duplicates with the optimizer idx
                 # if None, model skipped this optimizer
                 if loss_dict_new is not None:
-                    loss_dict.update(loss_dict_new)
+                    for k, v in loss_dict_new.items():
+                        if k in loss_dict:
+                            loss_dict[f"{k}-{idx}"] = v
+                        else:
+                            loss_dict[k] = v
+
             step_time = total_step_time
             outputs = outputs_per_optimizer
 
-        # update avg stats
+        # update avg runtime stats
         keep_avg_update = dict()
-        for key, value in log_dict.items():
-            keep_avg_update["avg_" + key] = value
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)
 
+        # update avg loss stats
+        update_train_values = dict()
+        for key, value in loss_dict.items():
+            update_train_values["avg_" + key] = value
+        self.keep_avg_train.update_values(update_train_values)
+
         # print training progress
         if self.total_steps_done % self.config.print_step == 0:
             # log learning rates
@@ -590,33 +610,27 @@ class Trainer:
             else:
                 current_lr = self.optimizer.param_groups[0]["lr"]
             lrs = {"current_lr": current_lr}
-            log_dict.update(lrs)
-            if grad_norm > 0:
-                log_dict.update({"grad_norm": grad_norm})
+
             # log run-time stats
-            log_dict.update(
+            loss_dict.update(
                 {
                     "step_time": round(step_time, 4),
                     "loader_time": round(loader_time, 4),
                 }
             )
             self.c_logger.print_train_step(
-                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+                batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
             )
 
         if self.args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load and don't log every step
             if self.total_steps_done % self.config.tb_plot_step == 0:
-                iter_stats = log_dict
-                iter_stats.update(loss_dict)
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
 
             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
-                    model_loss = (
-                        loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
-                    )
+                    target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                     save_checkpoint(
                         self.config,
                         self.model,
@@ -625,7 +639,7 @@ class Trainer:
                         self.total_steps_done,
                         self.epochs_done,
                         self.output_path,
-                        model_loss=model_loss,
+                        model_loss=target_avg_loss,
                     )
                 # training visualizations
                 figures, audios = None, None
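
The merge loop above renames colliding keys with the optimizer index so that, e.g., GAN-style models with two optimizers do not overwrite each other's stats. A pure-Python sketch of the same merge (the sample dicts are invented):

    def merge_loss_dicts(per_optimizer_dicts):
        # merge per-optimizer loss dicts, renaming duplicate keys with the optimizer idx
        merged = {}
        for idx, loss_dict_new in enumerate(per_optimizer_dicts):
            if loss_dict_new is None:  # model skipped this optimizer
                continue
            for k, v in loss_dict_new.items():
                if k in merged:
                    merged[f"{k}-{idx}"] = v
                else:
                    merged[k] = v
        return merged

    print(merge_loss_dicts([{"loss_0": 0.9, "amp_scaler": 1024}, {"loss_1": 0.4, "amp_scaler": 512}]))
    # {'loss_0': 0.9, 'amp_scaler': 1024, 'loss_1': 0.4, 'amp_scaler-1': 512}
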
@@ -666,6 +680,14 @@ class Trainer:
             self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
             if self.config.tb_model_param_stats:
                 self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+        # scheduler step after the epoch
+        if self.scheduler is not None and self.config.scheduler_after_epoch:
+            if isinstance(self.scheduler, list):
+                for scheduler in self.scheduler:
+                    if scheduler is not None:
+                        scheduler.step()
+            else:
+                self.scheduler.step()
 
     @staticmethod
     def _model_eval_step(
@@ -701,19 +723,22 @@ class Trainer:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
         with torch.no_grad():
-            outputs_per_optimizer = None
+            outputs = []
             loss_dict = {}
             if not isinstance(self.optimizer, list):
                 outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
             else:
-                outputs_per_optimizer = [None] * len(self.optimizer)
+                outputs = [None] * len(self.optimizer)
                 for idx, _ in enumerate(self.optimizer):
                     criterion = self.criterion
-                    outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
-                    outputs_per_optimizer[idx] = outputs
+                    outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
+                    outputs[idx] = outputs_
+                    if loss_dict_new is not None:
+                        loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                         loss_dict.update(loss_dict_new)
-                outputs = outputs_per_optimizer
+
+            loss_dict = self._detach_loss_dict(loss_dict)
 
             # update avg stats
             update_eval_values = dict()
@@ -764,6 +789,13 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
+            if self.eval_loader is None:
+                self.eval_loader = self.get_eval_dataloader(
+                    self.ap,
+                    self.data_eval,
+                    verbose=True,
+                )
+
             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
                 figures, audios = self.model.test_run(self.ap, samples, None)
@@ -816,10 +848,33 @@ class Trainer:
                 traceback.print_exc()
                 sys.exit(1)
 
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> float:
+        """Pick the target loss to compare models"""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller than the previous."""
+
+        # set the target loss to choose the best model
+        target_avg_loss = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+
+        # save the model and update the best_loss
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            target_avg_loss,
             self.best_loss,
             self.config,
             self.model,
@@ -927,7 +982,3 @@ class Trainer:
     return criterion
 
 
-def init_arguments():
-    train_config = TrainingArgs()
-    parser = train_config.init_argparse(arg_prefix="")
-    return parser
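
A standalone sketch of the checkpoint-selection rule that `_pick_target_avg_loss` (hunk @@ -816 above) implements, with a plain dict standing in for `KeepAverage` and invented stats:

    def pick_target_avg_loss(keep_avg, target_loss=None, num_optimizers=1):
        # an explicit `target_loss` from the config wins
        if target_loss:
            return keep_avg[f"avg_{target_loss}"]
        # with multiple optimizers, average the per-optimizer losses
        if num_optimizers > 1:
            return sum(keep_avg[f"avg_loss_{i}"] for i in range(num_optimizers)) / num_optimizers
        return keep_avg["avg_loss"]

    stats = {"avg_loss_0": 0.75, "avg_loss_1": 0.25}
    print(pick_target_avg_loss(stats, num_optimizers=2))  # 0.5, passed to save_best_model
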
- if config.has("characters_config"): + if config.has("characters") and config.characters is None: used_characters = parse_symbols() new_fields["characters"] = used_characters copy_model_files(config, experiment_path, new_fields) @@ -1065,6 +1120,12 @@ def process_args(args, config=None): return config, experiment_path, audio_path, c_logger, tb_logger +def init_arguments(): + train_config = TrainingArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + def init_training(argv: Union[List, Coqpit], config: Coqpit = None): """Initialization of a training run.""" if isinstance(argv, Coqpit):