mirror of https://github.com/coqui-ai/TTS.git
Update `trainer.py`
Fix multi-speaker initialization of models. Add changes for end2end `tts` models.
This commit is contained in:
parent b7f387b3dd
commit bf562cf437
TTS/trainer.py (147 changed lines)
@@ -2,6 +2,7 @@
 import importlib
 import logging
+import multiprocessing
 import os
 import platform
 import re
@@ -42,6 +43,8 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 
+multiprocessing.set_start_method("fork")
+
 if platform.system() != "Windows":
     # https://github.com/pytorch/pytorch/issues/973
     import resource
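
Note: `set_start_method("fork")` controls how DataLoader worker processes are created; it raises a `RuntimeError` when the start method has already been set, and `"fork"` does not exist on Windows. A defensive variant (a sketch, not what this commit does) would guard the call:

import multiprocessing

def set_fork_start_method() -> None:
    # "fork" keeps worker start-up cheap on Linux, but the start method can
    # only be set once per process and is unsupported on Windows.
    try:
        multiprocessing.set_start_method("fork")
    except (RuntimeError, ValueError):
        pass  # already set by the host application, or unsupported platform
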
@@ -149,7 +152,6 @@ class Trainer:
 
         # set and initialize Pytorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
-
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, tb_logger = process_args(args)
@@ -184,7 +186,7 @@ class Trainer:
         # init audio processor
         self.ap = AudioProcessor(**self.config.audio.to_dict())
 
-        # load dataset samples
+        # load data samples
         # TODO: refactor this
         if "datasets" in self.config:
             # load data for `tts` models
@@ -205,6 +207,10 @@
         else:
             self.model = self.get_model(self.config)
 
+        # init multispeaker settings of the model
+        if hasattr(self.model, "init_multispeaker"):
+            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+
         # setup criterion
         self.criterion = self.get_criterion(self.model)
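
Note: the trainer only duck-types this hook via `hasattr`; each model owns the actual signature. A minimal sketch of a model implementing it, assuming the usual `[text, wav_path, speaker_name]` sample layout (illustrative names, not the commit's code):

from torch import nn

class MyTTSModel(nn.Module):
    def init_multispeaker(self, config, data=None):
        # derive the speaker count from the combined train + eval samples
        # when the config does not pin it down (speaker name at index 2 is
        # an assumption for illustration)
        speakers = {item[2] for item in data} if data else set()
        self.num_speakers = getattr(config, "num_speakers", 0) or len(speakers)
        if self.num_speakers > 1:
            self.speaker_embedding = nn.Embedding(self.num_speakers, 256)
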
@@ -274,9 +280,9 @@ class Trainer:
         """
         # TODO: better model setup
         try:
-            model = setup_tts_model(config)
-        except ModuleNotFoundError:
             model = setup_vocoder_model(config)
+        except ModuleNotFoundError:
+            model = setup_tts_model(config)
         return model
 
     def restore_model(
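
Note: this dispatch relies on the `setup_model` helpers raising `ModuleNotFoundError` when `config.model` does not resolve inside their package. The same probe-style dispatch, sketched generically (import paths assumed from the file's own imports):

from TTS.tts.models import setup_model as setup_tts_model
from TTS.vocoder.models import setup_model as setup_vocoder_model

def resolve_model(config):
    # probe each model family in turn; the first one whose package can
    # import `config.model` wins
    for setup in (setup_vocoder_model, setup_tts_model):
        try:
            return setup(config)
        except ModuleNotFoundError:
            continue
    raise ValueError(f"Unknown model: {config.model}")
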
@@ -417,7 +423,7 @@ class Trainer:
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List],  # pylint: disable=protected-access
         config: Coqpit,
         optimizer_idx: int = None,
-    ) -> Tuple[Dict, Dict, int, torch.Tensor]:
+    ) -> Tuple[Dict, Dict, int]:
         """Perform a forward - backward pass and run the optimizer.
 
         Args:
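
Note: with this annotation change, `grad_norm` stops being a fourth return value and travels inside `loss_dict` instead (see the `grad_norm` / `grad_norm_{optimizer_idx}` keys added further down). Callers change roughly like this (sketch):

# old contract
# outputs, loss_dict, step_time, grad_norm = self._optimize(...)

# new contract: the norm is a logged value, not a return value
# outputs, loss_dict, step_time = self._optimize(...)
# grad_norm = loss_dict["grad_norm"]      # single-optimizer runs
# grad_norm = loss_dict["grad_norm_0"]    # multi-optimizer runs, idx 0
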
@@ -426,7 +432,7 @@ class Trainer:
             optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
             scaler (AMPScaler): AMP scaler.
             criterion (nn.Module): Model's criterion.
-            scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
             config (Coqpit): Model config.
             optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
@@ -436,6 +442,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
         """
+
         step_start_time = time.time()
         # zero-out optimizer
         optimizer.zero_grad()
@@ -448,11 +455,11 @@ class Trainer:
         # skip the rest
         if outputs is None:
             step_time = time.time() - step_start_time
-            return None, {}, step_time, 0
+            return None, {}, step_time
 
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
+            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
 
         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -463,7 +470,6 @@ class Trainer:
         else:
             grad_clip = 0.0  # meaning no gradient clipping
 
-        # TODO: compute grad norm
         if grad_clip <= 0:
             grad_norm = 0
 
@@ -474,15 +480,17 @@ class Trainer:
             with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                 scaled_loss.backward()
             grad_norm = torch.nn.utils.clip_grad_norm_(
-                amp.master_params(optimizer),
-                grad_clip,
+                amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
             )
         else:
             # model optimizer step in mixed precision mode
             scaler.scale(loss_dict["loss"]).backward()
-            scaler.unscale_(optimizer)
             if grad_clip > 0:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+                if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                    grad_norm = 0
             scale_prev = scaler.get_scale()
             scaler.step(optimizer)
             scaler.update()
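
Note: `error_if_nonfinite=False` matches the keyword added to `torch.nn.utils.clip_grad_norm_` in PyTorch 1.9: instead of raising on an inf/NaN total norm it returns the non-finite value, which is why the code zeroes it before logging (the `GradScaler` skips `optimizer.step()` for those batches anyway). A standalone sketch of the same logic:

import torch

def clip_and_report(model: torch.nn.Module, grad_clip: float) -> float:
    # clip_grad_norm_ returns the total norm computed *before* clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
        # report 0 instead of NaN/inf so running averages stay meaningful
        grad_norm = 0
    return float(grad_norm)
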
@@ -491,13 +499,13 @@ class Trainer:
             # main model optimizer step
             loss_dict["loss"].backward()
             if grad_clip > 0:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
             optimizer.step()
 
         step_time = time.time() - step_start_time
 
         # setup lr
-        if scheduler is not None and update_lr_scheduler:
+        if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
             scheduler.step()
 
         # detach losses
@@ -505,7 +513,9 @@ class Trainer:
         if optimizer_idx is not None:
             loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
-        return outputs, loss_dict, step_time, grad_norm
+            loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
+        else:
+            loss_dict["grad_norm"] = grad_norm
+        return outputs, loss_dict, step_time
 
     @staticmethod
     def _detach_loss_dict(loss_dict: Dict) -> Dict:
@@ -544,11 +554,10 @@ class Trainer:
 
         # conteainers to hold model outputs and losses for each optimizer.
         outputs_per_optimizer = None
-        log_dict = {}
         loss_dict = {}
         if not isinstance(self.optimizer, list):
             # training with a single optimizer
-            outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+            outputs, loss_dict_new, step_time = self._optimize(
                 batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
             )
             loss_dict.update(loss_dict_new)
@@ -560,25 +569,36 @@
                 criterion = self.criterion
                 scaler = self.scaler[idx] if self.use_amp_scaler else None
                 scheduler = self.scheduler[idx]
-                outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+                outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                 )
                 # skip the rest if the model returns None
                 total_step_time += step_time
                 outputs_per_optimizer[idx] = outputs
+                # merge loss_dicts from each optimizer
+                # rename duplicates with the optimizer idx
+                # if None, model skipped this optimizer
                 if loss_dict_new is not None:
-                    loss_dict.update(loss_dict_new)
+                    for k, v in loss_dict_new.items():
+                        if k in loss_dict:
+                            loss_dict[f"{k}-{idx}"] = v
+                        else:
+                            loss_dict[k] = v
             step_time = total_step_time
             outputs = outputs_per_optimizer
 
-        # update avg stats
+        # update avg runtime stats
         keep_avg_update = dict()
-        for key, value in log_dict.items():
-            keep_avg_update["avg_" + key] = value
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)
 
+        # update avg loss stats
+        update_eval_values = dict()
+        for key, value in loss_dict.items():
+            update_eval_values["avg_" + key] = value
+        self.keep_avg_train.update_values(update_eval_values)
+
         # print training progress
         if self.total_steps_done % self.config.print_step == 0:
             # log learning rates
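
Note: the merge policy for multi-optimizer runs re-adds colliding keys under a `-{idx}` suffix instead of overwriting them. A toy illustration (assumed semantics, stub values):

def merge_loss_dicts(loss_dict: dict, loss_dict_new: dict, idx: int) -> dict:
    # keys already seen keep their value; duplicates are renamed per optimizer
    for k, v in loss_dict_new.items():
        if k in loss_dict:
            loss_dict[f"{k}-{idx}"] = v
        else:
            loss_dict[k] = v
    return loss_dict

merged = merge_loss_dicts({"loss": 0.5}, {"loss": 0.7, "loss_disc": 0.1}, idx=1)
# -> {"loss": 0.5, "loss-1": 0.7, "loss_disc": 0.1}
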
@@ -590,33 +610,27 @@ class Trainer:
             else:
                 current_lr = self.optimizer.param_groups[0]["lr"]
             lrs = {"current_lr": current_lr}
-            log_dict.update(lrs)
-            if grad_norm > 0:
-                log_dict.update({"grad_norm": grad_norm})
 
             # log run-time stats
-            log_dict.update(
+            loss_dict.update(
                 {
                     "step_time": round(step_time, 4),
                     "loader_time": round(loader_time, 4),
                 }
             )
             self.c_logger.print_train_step(
-                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+                batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
             )
 
         if self.args.rank == 0:
             # Plot Training Iter Stats
             # reduce TB load and don't log every step
             if self.total_steps_done % self.config.tb_plot_step == 0:
-                iter_stats = log_dict
-                iter_stats.update(loss_dict)
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
             if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                 if self.config.checkpoint:
                     # checkpoint the model
-                    model_loss = (
-                        loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
-                    )
+                    target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                     save_checkpoint(
                         self.config,
                         self.model,
@@ -625,7 +639,7 @@ class Trainer:
                         self.total_steps_done,
                         self.epochs_done,
                         self.output_path,
-                        model_loss=model_loss,
+                        model_loss=target_avg_loss,
                     )
             # training visualizations
             figures, audios = None, None
@@ -666,6 +680,14 @@ class Trainer:
             self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
             if self.config.tb_model_param_stats:
                 self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+        # scheduler step after the epoch
+        if self.scheduler is not None and self.config.scheduler_after_epoch:
+            if isinstance(self.scheduler, list):
+                for scheduler in self.scheduler:
+                    if scheduler is not None:
+                        scheduler.step()
+            else:
+                self.scheduler.step()
 
     @staticmethod
     def _model_eval_step(
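
Note: together with the `not self.config.scheduler_after_epoch` guard added in `_optimize`, this block moves scheduler stepping from per-step to per-epoch when the config flag is set, which is what epoch-based schedulers such as `StepLR` or `ExponentialLR` expect. The list handling, extracted as a sketch:

from typing import List, Optional, Union
from torch.optim.lr_scheduler import _LRScheduler  # pylint: disable=protected-access

def step_schedulers(schedulers: Union[Optional[_LRScheduler], List[Optional[_LRScheduler]]]) -> None:
    # accept a single scheduler or one scheduler per optimizer; entries may
    # be None for optimizers that do not schedule their learning rate
    if not isinstance(schedulers, list):
        schedulers = [schedulers]
    for scheduler in schedulers:
        if scheduler is not None:
            scheduler.step()
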
@@ -701,19 +723,22 @@ class Trainer:
             Tuple[Dict, Dict]: Model outputs and losses.
         """
         with torch.no_grad():
-            outputs_per_optimizer = None
+            outputs = []
             loss_dict = {}
             if not isinstance(self.optimizer, list):
                 outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
             else:
-                outputs_per_optimizer = [None] * len(self.optimizer)
+                outputs = [None] * len(self.optimizer)
                 for idx, _ in enumerate(self.optimizer):
                     criterion = self.criterion
-                    outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
-                    outputs_per_optimizer[idx] = outputs
+                    outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
+                    outputs[idx] = outputs_
+
                     if loss_dict_new is not None:
                         loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                         loss_dict.update(loss_dict_new)
-                outputs = outputs_per_optimizer
 
         loss_dict = self._detach_loss_dict(loss_dict)
 
         # update avg stats
         update_eval_values = dict()
@@ -764,6 +789,13 @@ class Trainer:
         """Run test and log the results. Test run must be defined by the model.
         Model must return figures and audios to be logged by the Tensorboard."""
         if hasattr(self.model, "test_run"):
+            if self.eval_loader is None:
+                self.eval_loader = self.get_eval_dataloader(
+                    self.ap,
+                    self.data_eval,
+                    verbose=True,
+                )
+
             if hasattr(self.eval_loader.dataset, "load_test_samples"):
                 samples = self.eval_loader.dataset.load_test_samples(1)
                 figures, audios = self.model.test_run(self.ap, samples, None)
@@ -816,10 +848,33 @@ class Trainer:
             traceback.print_exc()
             sys.exit(1)
 
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+        """Pick the target loss to compare models"""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
     def save_best_model(self) -> None:
         """Save the best model. It only saves if the current target loss is smaller then the previous."""
 
+        # set the target loss to choose the best model
+        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+
+        # save the model and update the best_loss
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            target_loss_dict,
             self.best_loss,
             self.config,
             self.model,
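
Note: a worked example of `_pick_target_avg_loss` with no `target_loss` in the config and two optimizers (stub numbers; the KeepAverage container behaves like a dict of running averages):

keep_avg = {"avg_loss_0": 0.40, "avg_loss_1": 0.20}

# with multiple optimizers the target is the mean of the avg_loss_{idx} values
target_avg_loss = sum(keep_avg[f"avg_loss_{idx}"] for idx in range(2)) / 2
assert abs(target_avg_loss - 0.30) < 1e-9
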
@@ -927,7 +982,7 @@ class Trainer:
         return criterion
 
 
-def init_arguments():
+def getarguments():
     train_config = TrainingArgs()
     parser = train_config.init_argparse(arg_prefix="")
     return parser
@@ -1054,7 +1109,7 @@ def process_args(args, config=None):
     # if model characters are not set in the config file
     # save the default set to the config file for future
     # compatibility.
-    if config.has("characters_config"):
+    if config.has("characters") and config.characters is None:
         used_characters = parse_symbols()
         new_fields["characters"] = used_characters
     copy_model_files(config, experiment_path, new_fields)
@@ -1065,6 +1120,12 @@ def process_args(args, config=None):
     return config, experiment_path, audio_path, c_logger, tb_logger
 
 
+def init_arguments():
+    train_config = TrainingArgs()
+    parser = train_config.init_argparse(arg_prefix="")
+    return parser
+
+
 def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
     """Initialization of a training run."""
     if isinstance(argv, Coqpit):