Update `trainer.py`

Fix multi-speaker initialization of models. Add changes for end2end `tts`
models.
Eren Gölge 2021-08-07 21:30:07 +00:00
parent b7f387b3dd
commit bf562cf437
1 changed file with 104 additions and 43 deletions


@@ -2,6 +2,7 @@
import importlib
import logging
import multiprocessing
import os
import platform
import re
@@ -42,6 +43,8 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
multiprocessing.set_start_method("fork")
if platform.system() != "Windows":
# https://github.com/pytorch/pytorch/issues/973
import resource
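
For context, `set_start_method("fork")` and the conditional `resource` import relate to DataLoader worker behavior; the linked PyTorch issue is usually worked around by raising the soft open-file limit. A minimal sketch of that pattern, assuming the hard limit allows it (illustrative only, not necessarily what the file does with the import):

```python
import platform

if platform.system() != "Windows":
    # "fork" is unavailable on Windows; the resource module is POSIX-only
    import resource

    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    # raise the soft file-descriptor limit so many DataLoader workers don't hit
    # "Too many open files" (see pytorch/pytorch#973)
    resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft_limit, 4096), hard_limit))
```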
@@ -149,7 +152,6 @@ class Trainer:
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, tb_logger = process_args(args)
@@ -184,7 +186,7 @@ class Trainer:
# init audio processor
self.ap = AudioProcessor(**self.config.audio.to_dict())
# load dataset samples
# load data samples
# TODO: refactor this
if "datasets" in self.config:
# load data for `tts` models
@@ -205,6 +207,10 @@ class Trainer:
else:
self.model = self.get_model(self.config)
# init multispeaker settings of the model
if hasattr(self.model, "init_multispeaker"):
self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
# setup criterion
self.criterion = self.get_criterion(self.model)
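
This hook keeps speaker handling inside the model: the trainer only checks for `init_multispeaker` and hands over the config plus the combined train/eval samples. A hedged sketch of what a model-side implementation could look like (the sample format, embedding size, and attribute names below are illustrative assumptions, not taken from this commit):

```python
from typing import List

from coqpit import Coqpit
from torch import nn


class ExampleTTSModel(nn.Module):
    """Toy model showing the optional multi-speaker initialization hook."""

    def init_multispeaker(self, config: Coqpit, data: List = None):
        # assumption: each sample carries a "speaker_name" field
        speakers = sorted({item["speaker_name"] for item in data}) if data else []
        self.num_speakers = len(speakers)
        self.speaker_ids = {name: idx for idx, name in enumerate(speakers)}
        if self.num_speakers > 1:
            # assumption: a fixed 256-dim speaker embedding; real models read this from `config`
            self.speaker_embedding = nn.Embedding(self.num_speakers, 256)
```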
@@ -274,9 +280,9 @@ class Trainer:
"""
# TODO: better model setup
try:
model = setup_tts_model(config)
except ModuleNotFoundError:
model = setup_vocoder_model(config)
except ModuleNotFoundError:
model = setup_tts_model(config)
return model
def restore_model(
@@ -417,7 +423,7 @@ class Trainer:
scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access
config: Coqpit,
optimizer_idx: int = None,
) -> Tuple[Dict, Dict, int, torch.Tensor]:
) -> Tuple[Dict, Dict, int]:
"""Perform a forward - backward pass and run the optimizer.
Args:
@@ -426,7 +432,7 @@ class Trainer:
optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
scaler (AMPScaler): AMP scaler.
criterion (nn.Module): Model's criterion.
scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
config (Coqpit): Model config.
optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
@@ -436,6 +442,7 @@ class Trainer:
Returns:
Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
"""
step_start_time = time.time()
# zero-out optimizer
optimizer.zero_grad()
@@ -448,11 +455,11 @@ class Trainer:
# skip the rest
if outputs is None:
step_time = time.time() - step_start_time
return None, {}, step_time, 0
return None, {}, step_time
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
# set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None:
@@ -463,7 +470,6 @@ class Trainer:
else:
grad_clip = 0.0 # meaning no gradient clipping
# TODO: compute grad norm
if grad_clip <= 0:
grad_norm = 0
@@ -474,15 +480,17 @@ class Trainer:
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer),
grad_clip,
amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
)
else:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
scale_prev = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
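
The mixed-precision branch above unscales the gradients before clipping and treats a non-finite norm as zero so logging does not crash. A condensed sketch of that sequence, assuming a `torch.cuda.amp.GradScaler` and PyTorch >= 1.9 (where `error_if_nonfinite` exists); the helper name is made up for illustration:

```python
import torch


def amp_step(loss, model, optimizer, scaler, grad_clip):
    # scale the loss, backprop, then unscale so the clip threshold is meaningful
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    grad_norm = torch.tensor(0.0)
    if grad_clip > 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), grad_clip, error_if_nonfinite=False
        )
        # clip_grad_norm_ may return inf/NaN; report it as 0 instead of failing
        if torch.isnan(grad_norm) or torch.isinf(grad_norm):
            grad_norm = torch.tensor(0.0)
    scale_prev = scaler.get_scale()
    scaler.step(optimizer)  # skipped internally if any grad is non-finite
    scaler.update()
    # if the scale shrank, the step was skipped for this batch
    update_lr = scaler.get_scale() == scale_prev
    return grad_norm, update_lr
```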
@@ -491,13 +499,13 @@ class Trainer:
# main model optimizer step
loss_dict["loss"].backward()
if grad_clip > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
optimizer.step()
step_time = time.time() - step_start_time
# setup lr
if scheduler is not None and update_lr_scheduler:
if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
scheduler.step()
# detach losses
@@ -505,7 +513,9 @@ class Trainer:
if optimizer_idx is not None:
loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
return outputs, loss_dict, step_time, grad_norm
else:
loss_dict["grad_norm"] = grad_norm
return outputs, loss_dict, step_time
@staticmethod
def _detach_loss_dict(loss_dict: Dict) -> Dict:
@@ -544,11 +554,10 @@ class Trainer:
# containers to hold model outputs and losses for each optimizer.
outputs_per_optimizer = None
log_dict = {}
loss_dict = {}
if not isinstance(self.optimizer, list):
# training with a single optimizer
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
)
loss_dict.update(loss_dict_new)
@@ -560,25 +569,36 @@ class Trainer:
criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None
scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time, grad_norm = self._optimize(
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
)
# skip the rest if the model returns None
total_step_time += step_time
outputs_per_optimizer[idx] = outputs
# merge loss_dicts from each optimizer
# rename duplicates with the optimizer idx
# if None, model skipped this optimizer
if loss_dict_new is not None:
loss_dict.update(loss_dict_new)
for k, v in loss_dict_new.items():
if k in loss_dict:
loss_dict[f"{k}-{idx}"] = v
else:
loss_dict[k] = v
step_time = total_step_time
outputs = outputs_per_optimizer
# update avg stats
# update avg runtime stats
keep_avg_update = dict()
for key, value in log_dict.items():
keep_avg_update["avg_" + key] = value
keep_avg_update["avg_loader_time"] = loader_time
keep_avg_update["avg_step_time"] = step_time
self.keep_avg_train.update_values(keep_avg_update)
# update avg loss stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values["avg_" + key] = value
self.keep_avg_train.update_values(update_eval_values)
# print training progress
if self.total_steps_done % self.config.print_step == 0:
# log learning rates
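
The merge above renames clashing keys with the optimizer index instead of overwriting them, so losses from different optimizers stay visible. A tiny standalone sketch of that behavior (the helper name is made up for illustration):

```python
from typing import Dict


def merge_loss_dicts(loss_dict: Dict, loss_dict_new: Dict, idx: int) -> Dict:
    """Merge a per-optimizer loss dict, renaming duplicate keys with the optimizer idx."""
    for k, v in loss_dict_new.items():
        if k in loss_dict:
            loss_dict[f"{k}-{idx}"] = v
        else:
            loss_dict[k] = v
    return loss_dict


# e.g. two optimizers both reporting "loss_disc"
merged = merge_loss_dicts({"loss_disc": 0.5}, {"loss_disc": 0.7}, idx=1)
# -> {"loss_disc": 0.5, "loss_disc-1": 0.7}
```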
@@ -590,33 +610,27 @@ class Trainer:
else:
current_lr = self.optimizer.param_groups[0]["lr"]
lrs = {"current_lr": current_lr}
log_dict.update(lrs)
if grad_norm > 0:
log_dict.update({"grad_norm": grad_norm})
# log run-time stats
log_dict.update(
loss_dict.update(
{
"step_time": round(step_time, 4),
"loader_time": round(loader_time, 4),
}
)
self.c_logger.print_train_step(
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
)
if self.args.rank == 0:
# Plot Training Iter Stats
# reduce TB load and don't log every step
if self.total_steps_done % self.config.tb_plot_step == 0:
iter_stats = log_dict
iter_stats.update(loss_dict)
self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
if self.config.checkpoint:
# checkpoint the model
model_loss = (
loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
)
target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
save_checkpoint(
self.config,
self.model,
@@ -625,7 +639,7 @@ class Trainer:
self.total_steps_done,
self.epochs_done,
self.output_path,
model_loss=model_loss,
model_loss=target_avg_loss,
)
# training visualizations
figures, audios = None, None
@@ -666,6 +680,14 @@ class Trainer:
self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
if self.config.tb_model_param_stats:
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
# scheduler step after the epoch
if self.scheduler is not None and self.config.scheduler_after_epoch:
if isinstance(self.scheduler, list):
for scheduler in self.scheduler:
if scheduler is not None:
scheduler.step()
else:
self.scheduler.step()
@staticmethod
def _model_eval_step(
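
With `scheduler_after_epoch` enabled, per-step stepping is skipped (see the `_optimize` change above) and the scheduler(s) advance once per epoch instead. A compact sketch of the two modes; this is an abstracted helper, not the Trainer's actual control flow, which splits the cases across `_optimize` and the epoch loop:

```python
def step_schedulers(schedulers, scheduler_after_epoch: bool, end_of_epoch: bool) -> None:
    """Step LR scheduler(s) either every optimizer step or once per epoch."""
    if schedulers is None:
        return
    # per-epoch mode steps only at epoch boundaries; per-step mode steps everywhere else
    if scheduler_after_epoch != end_of_epoch:
        return
    schedulers = schedulers if isinstance(schedulers, list) else [schedulers]
    for scheduler in schedulers:
        if scheduler is not None:
            scheduler.step()
```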
@@ -701,19 +723,22 @@ class Trainer:
Tuple[Dict, Dict]: Model outputs and losses.
"""
with torch.no_grad():
outputs_per_optimizer = None
outputs = []
loss_dict = {}
if not isinstance(self.optimizer, list):
outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
else:
outputs_per_optimizer = [None] * len(self.optimizer)
outputs = [None] * len(self.optimizer)
for idx, _ in enumerate(self.optimizer):
criterion = self.criterion
outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
outputs_per_optimizer[idx] = outputs
outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
outputs[idx] = outputs_
if loss_dict_new is not None:
loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
loss_dict.update(loss_dict_new)
outputs = outputs_per_optimizer
loss_dict = self._detach_loss_dict(loss_dict)
# update avg stats
update_eval_values = dict()
@@ -764,6 +789,13 @@ class Trainer:
"""Run test and log the results. Test run must be defined by the model.
Model must return figures and audios to be logged by the Tensorboard."""
if hasattr(self.model, "test_run"):
if self.eval_loader is None:
self.eval_loader = self.get_eval_dataloader(
self.ap,
self.data_eval,
verbose=True,
)
if hasattr(self.eval_loader.dataset, "load_test_samples"):
samples = self.eval_loader.dataset.load_test_samples(1)
figures, audios = self.model.test_run(self.ap, samples, None)
@@ -816,10 +848,33 @@ class Trainer:
traceback.print_exc()
sys.exit(1)
def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
"""Pick the target loss to compare models"""
target_avg_loss = None
# return if target loss defined in the model config
if "target_loss" in self.config and self.config.target_loss:
return keep_avg_target[f"avg_{self.config.target_loss}"]
# take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
if isinstance(self.optimizer, list):
target_avg_loss = 0
for idx in range(len(self.optimizer)):
target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
target_avg_loss /= len(self.optimizer)
else:
target_avg_loss = keep_avg_target["avg_loss"]
return target_avg_loss
def save_best_model(self) -> None:
"""Save the best model. It only saves if the current target loss is smaller then the previous."""
# set the target loss to choose the best model
target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
# save the model and update the best_loss
self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
target_loss_dict,
self.best_loss,
self.config,
self.model,
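
`_pick_target_avg_loss` turns the running averages into a single comparable number: an explicit `target_loss` from the config wins, otherwise the per-optimizer `avg_loss_{idx}` values are averaged, otherwise `avg_loss` is used. A small self-contained sketch of that selection rule, with the keep-average store modeled as a plain dict for illustration (not the Trainer method itself):

```python
from typing import Dict, Optional


def pick_target_avg_loss(keep_avg: Dict[str, float], target_loss: Optional[str], num_optimizers: int) -> float:
    """Illustrative re-implementation of the target-loss selection rule."""
    if target_loss:
        # explicit metric configured, e.g. target_loss="loss_disc"
        return keep_avg[f"avg_{target_loss}"]
    if num_optimizers > 1:
        # average the per-optimizer losses (named loss_0, loss_1, ...)
        return sum(keep_avg[f"avg_loss_{idx}"] for idx in range(num_optimizers)) / num_optimizers
    return keep_avg["avg_loss"]


# e.g. a GAN-style setup with two optimizers and no explicit target_loss
print(pick_target_avg_loss({"avg_loss_0": 0.4, "avg_loss_1": 0.6}, None, 2))  # 0.5
```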
@@ -927,7 +982,7 @@ class Trainer:
return criterion
def init_arguments():
def getarguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
@@ -1054,7 +1109,7 @@ def process_args(args, config=None):
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if config.has("characters_config"):
if config.has("characters") and config.characters is None:
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
@@ -1065,6 +1120,12 @@ def process_args(args, config=None):
return config, experiment_path, audio_path, c_logger, tb_logger
def init_arguments():
train_config = TrainingArgs()
parser = train_config.init_argparse(arg_prefix="")
return parser
def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
"""Initialization of a training run."""
if isinstance(argv, Coqpit):