Update `trainer.py`

Fix multi-speaker initialization of models. Add changes for end2end `tts` models.
Eren Gölge 2021-08-07 21:30:07 +00:00
parent b7f387b3dd
commit bf562cf437
1 changed file with 104 additions and 43 deletions


@@ -2,6 +2,7 @@
 import importlib
 import logging
+import multiprocessing
 import os
 import platform
 import re
@@ -42,6 +43,8 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 
+multiprocessing.set_start_method("fork")
+
 if platform.system() != "Windows":
     # https://github.com/pytorch/pytorch/issues/973
     import resource
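
A note on the added `set_start_method("fork")` line: it runs unconditionally at import time, it raises if a start method was already configured by an embedding process, and "fork" is POSIX-only. A guarded variant would look like this (a sketch, not part of the commit):

import multiprocessing
import platform

if platform.system() != "Windows":
    try:
        # "fork" keeps DataLoader workers cheap on Linux/macOS, but it is
        # unavailable on Windows and can only be set once per process
        multiprocessing.set_start_method("fork")
    except RuntimeError:
        # the start method was already configured elsewhere
        pass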
@@ -149,7 +152,6 @@
         # set and initialize Pytorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
-
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, tb_logger = process_args(args)
@@ -184,7 +186,7 @@
         # init audio processor
         self.ap = AudioProcessor(**self.config.audio.to_dict())
 
-        # load dataset samples
+        # load data samples
         # TODO: refactor this
         if "datasets" in self.config:
             # load data for `tts` models
@@ -205,6 +207,10 @@
         else:
             self.model = self.get_model(self.config)
 
+        # init multispeaker settings of the model
+        if hasattr(self.model, "init_multispeaker"):
+            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+
         # setup criterion
         self.criterion = self.get_criterion(self.model)
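
The trainer discovers multi-speaker support by duck typing: `init_multispeaker` is only called when the model defines it, so single-speaker models need no changes. A hypothetical model showing the expected shape of that hook (the class name, sample layout, and embedding size are illustrative assumptions, not the commit's code):

from typing import List

from torch import nn


class MyTTSModel(nn.Module):
    """Hypothetical model, used only to illustrate the hook."""

    def init_multispeaker(self, config, data: List = None):
        # prefer an explicit speaker count from the config, otherwise
        # derive it from the samples the Trainer passes in
        # (data_train + data_eval)
        self.num_speakers = getattr(config, "num_speakers", 0)
        if not self.num_speakers and data is not None:
            self.num_speakers = len({sample[2] for sample in data})  # assumed sample layout
        if self.num_speakers > 1:
            # one learned embedding vector per speaker
            self.speaker_embedding = nn.Embedding(self.num_speakers, 512)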
@@ -274,9 +280,9 @@
         """
         # TODO: better model setup
         try:
+            model = setup_tts_model(config)
+        except ModuleNotFoundError:
             model = setup_vocoder_model(config)
-        except ModuleNotFoundError:
-            model = setup_tts_model(config)
         return model
 
     def restore_model(
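
Reordering the `try`/`except` makes the `tts` namespace the primary lookup and the vocoder namespace the fallback, which matters because the first lookup that succeeds wins. The dispatch pattern in isolation (a free-standing sketch with a hypothetical name):

def setup_any_model(config):
    # resolve the model in the `tts` namespace first; fall back to the
    # vocoder namespace only when it is not found there
    try:
        return setup_tts_model(config)
    except ModuleNotFoundError:
        return setup_vocoder_model(config)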
@@ -417,7 +423,7 @@
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List],  # pylint: disable=protected-access
         config: Coqpit,
         optimizer_idx: int = None,
-    ) -> Tuple[Dict, Dict, int, torch.Tensor]:
+    ) -> Tuple[Dict, Dict, int]:
         """Perform a forward - backward pass and run the optimizer.
 
         Args:
@@ -426,7 +432,7 @@
             optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
             scaler (AMPScaler): AMP scaler.
             criterion (nn.Module): Model's criterion.
-            scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
             config (Coqpit): Model config.
             optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
 
@@ -436,6 +442,7 @@
         Returns:
             Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
         """
+
         step_start_time = time.time()
         # zero-out optimizer
         optimizer.zero_grad()
@@ -448,11 +455,11 @@
             # skip the rest
             if outputs is None:
                 step_time = time.time() - step_start_time
-                return None, {}, step_time, 0
+                return None, {}, step_time
 
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
+            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
 
         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -463,7 +470,6 @@
         else:
             grad_clip = 0.0  # meaning no gradient clipping
 
-        # TODO: compute grad norm
         if grad_clip <= 0:
             grad_norm = 0
 
@@ -474,15 +480,17 @@
                 with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                     scaled_loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer),
-                    grad_clip,
+                    amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
                 )
             else:
                 # model optimizer step in mixed precision mode
                 scaler.scale(loss_dict["loss"]).backward()
-                scaler.unscale_(optimizer)
                 if grad_clip > 0:
+                    scaler.unscale_(optimizer)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                    # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                        grad_norm = 0
-                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                 scale_prev = scaler.get_scale()
                 scaler.step(optimizer)
                 scaler.update()
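
The mixed-precision branch now follows the standard `torch.cuda.amp` recipe: gradients must be unscaled before clipping so the threshold applies to true gradient magnitudes, and `scaler.step()` already skips the update when gradients are non-finite, which is why a NaN/inf norm can safely be reported as 0. The recipe in isolation (a sketch; `model`, `criterion`, and `loader` are assumed to exist):

import torch

optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale before clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)      # skipped internally if grads are inf/NaN
    scaler.update()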
@@ -491,13 +499,13 @@
             # main model optimizer step
             loss_dict["loss"].backward()
             if grad_clip > 0:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
             optimizer.step()
 
         step_time = time.time() - step_start_time
 
         # setup lr
-        if scheduler is not None and update_lr_scheduler:
+        if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
             scheduler.step()
 
         # detach losses
@@ -505,7 +513,9 @@
         if optimizer_idx is not None:
             loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
             loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
-        return outputs, loss_dict, step_time, grad_norm
+        else:
+            loss_dict["grad_norm"] = grad_norm
+        return outputs, loss_dict, step_time
 
     @staticmethod
     def _detach_loss_dict(loss_dict: Dict) -> Dict:
@@ -544,11 +554,10 @@
         # conteainers to hold model outputs and losses for each optimizer.
         outputs_per_optimizer = None
-        log_dict = {}
         loss_dict = {}
 
         if not isinstance(self.optimizer, list):
             # training with a single optimizer
-            outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+            outputs, loss_dict_new, step_time = self._optimize(
                 batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
             )
             loss_dict.update(loss_dict_new)
@@ -560,25 +569,36 @@
                 criterion = self.criterion
                 scaler = self.scaler[idx] if self.use_amp_scaler else None
                 scheduler = self.scheduler[idx]
-                outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+                outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                 )
                 # skip the rest if the model returns None
                 total_step_time += step_time
                 outputs_per_optimizer[idx] = outputs
+                # merge loss_dicts from each optimizer
+                # rename duplicates with the optimizer idx
                 # if None, model skipped this optimizer
                 if loss_dict_new is not None:
-                    loss_dict.update(loss_dict_new)
+                    for k, v in loss_dict_new.items():
+                        if k in loss_dict:
+                            loss_dict[f"{k}-{idx}"] = v
+                        else:
+                            loss_dict[k] = v
+            step_time = total_step_time
             outputs = outputs_per_optimizer
 
-        # update avg stats
+        # update avg runtime stats
         keep_avg_update = dict()
-        for key, value in log_dict.items():
-            keep_avg_update["avg_" + key] = value
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)
 
+        # update avg loss stats
+        update_eval_values = dict()
+        for key, value in loss_dict.items():
+            update_eval_values["avg_" + key] = value
+        self.keep_avg_train.update_values(update_eval_values)
+
         # print training progress
         if self.total_steps_done % self.config.print_step == 0:
             # log learning rates
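
The merge logic above keeps one flat `loss_dict` across optimizers without silently overwriting keys: the first occurrence keeps its name, and later collisions get an `-{idx}` suffix. In isolation (a sketch with illustrative values):

def merge_loss_dicts(loss_dict, loss_dict_new, idx):
    # same rule as the loop above: rename colliding keys with the
    # optimizer index instead of overwriting the earlier value
    for k, v in loss_dict_new.items():
        if k in loss_dict:
            loss_dict[f"{k}-{idx}"] = v
        else:
            loss_dict[k] = v
    return loss_dict


merged = merge_loss_dicts({"loss_disc": 0.51}, {"loss_disc": 0.70, "loss_gen": 0.93}, idx=1)
# -> {"loss_disc": 0.51, "loss_disc-1": 0.70, "loss_gen": 0.93}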
@@ -590,33 +610,27 @@
             else:
                 current_lr = self.optimizer.param_groups[0]["lr"]
             lrs = {"current_lr": current_lr}
-            log_dict.update(lrs)
-            if grad_norm > 0:
-                log_dict.update({"grad_norm": grad_norm})
 
             # log run-time stats
-            log_dict.update(
+            loss_dict.update(
                 {
                     "step_time": round(step_time, 4),
                     "loader_time": round(loader_time, 4),
                 }
             )
             self.c_logger.print_train_step(
-                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+                batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
             )
 
         if self.args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load and don't log every step
             if self.total_steps_done % self.config.tb_plot_step == 0:
-                iter_stats = log_dict
-                iter_stats.update(loss_dict)
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
-                    model_loss = (
-                        loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
-                    )
+                    target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                     save_checkpoint(
                         self.config,
                         self.model,
@@ -625,7 +639,7 @@
                         self.total_steps_done,
                         self.epochs_done,
                         self.output_path,
-                        model_loss=model_loss,
+                        model_loss=target_avg_loss,
                     )
 
                 # training visualizations
                 figures, audios = None, None
@@ -666,6 +680,14 @@
             self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
             if self.config.tb_model_param_stats:
                 self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
 
+        # scheduler step after the epoch
+        if self.scheduler is not None and self.config.scheduler_after_epoch:
+            if isinstance(self.scheduler, list):
+                for scheduler in self.scheduler:
+                    if scheduler is not None:
+                        scheduler.step()
+            else:
+                self.scheduler.step()
 
     @staticmethod
     def _model_eval_step(
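
Together with the `not self.config.scheduler_after_epoch` guard added in `_optimize`, this block makes LR scheduling run either once per optimizer step or once per epoch, never both. The two modes side by side (a sketch; `optimizer`, `loader`, `train_step`, and the `scheduler_after_epoch` flag are assumed):

import torch

scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

for epoch in range(10):
    for batch in loader:
        train_step(batch)
        if not scheduler_after_epoch:
            scheduler.step()  # decay once per optimizer step
    if scheduler_after_epoch:
        scheduler.step()      # decay once per epoch instead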
@@ -701,19 +723,22 @@
             Tuple[Dict, Dict]: Model outputs and losses.
         """
         with torch.no_grad():
-            outputs_per_optimizer = None
+            outputs = []
             loss_dict = {}
             if not isinstance(self.optimizer, list):
                 outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
             else:
-                outputs_per_optimizer = [None] * len(self.optimizer)
+                outputs = [None] * len(self.optimizer)
                 for idx, _ in enumerate(self.optimizer):
                     criterion = self.criterion
-                    outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
-                    outputs_per_optimizer[idx] = outputs
+                    outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
+                    outputs[idx] = outputs_
                     if loss_dict_new is not None:
+                        loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                         loss_dict.update(loss_dict_new)
-                outputs = outputs_per_optimizer
+
+        loss_dict = self._detach_loss_dict(loss_dict)
 
         # update avg stats
         update_eval_values = dict()
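
Detaching the eval losses in one place means every value entering the keep-average bookkeeping is a plain number rather than a graph-attached tensor. `_detach_loss_dict` is defined elsewhere in this file; presumably it behaves like this sketch:

from typing import Dict

import torch


def detach_loss_dict(loss_dict: Dict) -> Dict:
    # convert scalar loss tensors to floats so the stored values cannot
    # keep the autograd graph (and its activations) alive
    detached = {}
    for key, value in loss_dict.items():
        if isinstance(value, torch.Tensor):
            detached[key] = value.detach().item()
        else:
            detached[key] = value
    return detached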
@@ -764,6 +789,13 @@
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
+            if self.eval_loader is None:
+                self.eval_loader = self.get_eval_dataloader(
+                    self.ap,
+                    self.data_eval,
+                    verbose=True,
+                )
+
             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
                 figures, audios = self.model.test_run(self.ap, samples, None)
@@ -816,10 +848,33 @@
             traceback.print_exc()
             sys.exit(1)
 
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+        """Pick the target loss to compare models"""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""
+
+        # set the target loss to choose the best model
+        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+
+        # save the model and update the best_loss
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            target_loss_dict,
             self.best_loss,
             self.config,
             self.model,
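
How `_pick_target_avg_loss` resolves in its three cases, against the averaged values tracked by `KeepAverage` (a plain dict stands in for it here; numbers are illustrative):

keep_avg = {"avg_loss_0": 0.42, "avg_loss_1": 0.58, "avg_loss": 0.47}

# 1. config.target_loss = "loss_1"   -> keep_avg["avg_loss_1"]  == 0.58
# 2. no target_loss, two optimizers  -> mean of per-optimizer losses
target = (keep_avg["avg_loss_0"] + keep_avg["avg_loss_1"]) / 2  # 0.50
# 3. no target_loss, one optimizer   -> keep_avg["avg_loss"]    == 0.47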
@@ -927,7 +982,7 @@
         return criterion
 
 
-def init_arguments():
+def getarguments():
     train_config = TrainingArgs()
     parser = train_config.init_argparse(arg_prefix="")
     return parser
@@ -1054,7 +1109,7 @@
         # if model characters are not set in the config file
         # save the default set to the config file for future
         # compatibility.
-        if config.has("characters_config"):
+        if config.has("characters") and config.characters is None:
             used_characters = parse_symbols()
             new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)
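
The old check looked up a key ("characters_config") that model configs do not define, so the default character set was never written out; the corrected check fires exactly when the config declares a `characters` field that is still unset. With a toy Coqpit config (illustrative only):

from dataclasses import dataclass

from coqpit import Coqpit


@dataclass
class ToyConfig(Coqpit):
    characters: dict = None


config = ToyConfig()
assert config.has("characters")   # the field exists...
assert config.characters is None  # ...but is unset, so defaults get saved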
@@ -1065,6 +1120,12 @@
     return config, experiment_path, audio_path, c_logger, tb_logger
 
 
+def init_arguments():
+    train_config = TrainingArgs()
+    parser = train_config.init_argparse(arg_prefix="")
+    return parser
+
+
 def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     """Initialization of a training run."""
     if isinstance(argv, Coqpit):
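
For reference, roughly how a training script would consume these entry points (a hypothetical sketch; the unpacking assumes `init_training` forwards what `process_args` returns, with the parsed args prepended):

import sys


def main():
    # hypothetical usage; the tuple mirrors process_args() plus args,
    # with audio_path discarded
    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
    trainer = Trainer(args, config, output_path, c_logger, tb_logger)
    trainer.fit()


if __name__ == "__main__":
    main()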