mirror of https://github.com/coqui-ai/TTS.git
Small trainer refactoring
1. Use a single GradScaler for all the optimizers.
2. Save terminal logs to a file. In DDP mode, each worker creates `trainer_N_log.txt`.
3. Allow only the main worker (rank == 0) to write to Tensorboard.
4. Pass only the parameters owned by the target optimizer to gradient clipping (`clip_grad_norm_`).
This commit is contained in:
parent 3ab8cef99e
commit 5911eec3b1
@@ -43,7 +43,7 @@ def main():
         my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
         command[-1] = "--rank={}".format(i)
         # prevent stdout for processes with rank != 0
-        stdout = None if i == 0 else open(os.devnull, "w")
+        stdout = None
         p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)  # pylint: disable=consider-using-with
         processes.append(p)
         print(command)
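The launcher change above stops redirecting the stdout of rank > 0 workers to /dev/null; per-rank terminal filtering is handled later by the trainer's own log setup. A minimal sketch of that launch loop under the same assumptions (the `train.py` command, GPU count, and the appended `--rank` flag are illustrative, not taken from the repo):

import os
import subprocess
import sys

def launch_workers(num_gpus, command):
    """Spawn one training process per GPU, passing a distinct --rank to each."""
    processes = []
    for i in range(num_gpus):
        my_env = os.environ.copy()
        my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)  # per-process egg cache, as in the diff
        worker_cmd = command + ["--rank={}".format(i)]
        # stdout is no longer sent to /dev/null for rank > 0; each worker later
        # tees its own output into trainer_<rank>_log.txt instead.
        p = subprocess.Popen([sys.executable] + worker_cmd, env=my_env)  # pylint: disable=consider-using-with
        processes.append(p)
    for p in processes:
        p.wait()

if __name__ == "__main__":
    # hypothetical entry point; the script name and arguments are placeholders
    launch_workers(num_gpus=2, command=["train.py", "--config_path", "config.json"])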
TTS/trainer.py (108 changed lines)
@@ -39,7 +39,7 @@ from TTS.utils.generic_utils import (
     to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -157,29 +157,33 @@ class Trainer:
         - TPU training
         """

-        # set and initialize Pytorch runtime
-        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, dashboard_logger = process_args(args)

+        self.output_path = output_path
         self.args = args
         self.config = config
-        self.output_path = output_path
         self.config.output_log_path = output_path

+        # setup logging
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
+        self._setup_logger_config(log_file)
+
+        # set and initialize Pytorch runtime
+        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
+
         # init loggers
         self.c_logger = ConsoleLogger() if c_logger is None else c_logger
         self.dashboard_logger = dashboard_logger

-        if self.dashboard_logger is None:
-            self.dashboard_logger = init_logger(config)
+        # only allow dashboard logging for the main process in DDP mode
+        if self.dashboard_logger is None and args.rank == 0:
+            self.dashboard_logger = init_dashboard_logger(config)

         if not self.config.log_model_step:
             self.config.log_model_step = self.config.save_step

-        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
-        self._setup_logger_config(log_file)
-
         self.total_steps_done = 0
         self.epochs_done = 0
         self.restore_step = 0
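The constructor now builds the dashboard logger only on the main process, so in DDP mode only rank 0 writes Tensorboard events. A minimal sketch of the same gating using torch's own SummaryWriter (the function name, rank argument, and log directory are assumptions; the repo's `init_dashboard_logger` wraps its own Tensorboard/W&B loggers instead):

import os
from torch.utils.tensorboard import SummaryWriter

def init_dashboard_writer(log_dir: str, rank: int):
    """Return a SummaryWriter on the main process and None on every other rank."""
    if rank != 0:
        return None  # non-main workers skip dashboard logging entirely
    os.makedirs(log_dir, exist_ok=True)
    return SummaryWriter(log_dir=log_dir)

# usage: every rank runs the same code, but only rank 0 actually logs
writer = init_dashboard_writer("runs/example", rank=0)
if writer is not None:
    writer.add_scalar("train/loss", 0.5, global_step=1)
    writer.flush()
    writer.close()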
@@ -253,10 +257,10 @@ class Trainer:
             if self.use_apex:
                 self.scaler = None
                 self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
-            if isinstance(self.optimizer, list):
-                self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
-            else:
-                self.scaler = torch.cuda.amp.GradScaler()
+            # if isinstance(self.optimizer, list):
+            #     self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
+            # else:
+            self.scaler = torch.cuda.amp.GradScaler()
         else:
             self.scaler = None
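Per item 1 of the commit message, a single GradScaler is now shared by all optimizers instead of keeping one scaler per optimizer. A minimal sketch of that pattern with two optimizers, loosely following PyTorch's documented multi-optimizer AMP recipe (the module names, shapes, and losses are made up; the point is one scaler, two steps, one update per iteration):

import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# stand-in "generator" and "discriminator" modules (shapes are arbitrary)
gen = nn.Linear(8, 8).to(device)
disc = nn.Linear(8, 1).to(device)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

# a single GradScaler shared by both optimizers (disabled automatically on CPU)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(4, 8, device=device)

# ---- one training iteration ----
opt_g.zero_grad()
opt_d.zero_grad()

with torch.cuda.amp.autocast(enabled=use_cuda):
    loss_g = gen(x).pow(2).mean()    # placeholder generator loss
    loss_d = disc(x).pow(2).mean()   # placeholder discriminator loss

scaler.scale(loss_g).backward()      # both backward passes go through the same scaler
scaler.scale(loss_d).backward()
scaler.step(opt_g)                   # unscale and step each optimizer
scaler.step(opt_d)
scaler.update()                      # adjust the loss scale once per iteration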
@@ -325,7 +329,8 @@ class Trainer:
                 return obj

         print(" > Restoring from %s ..." % os.path.basename(restore_path))
-        checkpoint = load_fsspec(restore_path)
+        # checkpoint = load_fsspec(restore_path)
+        checkpoint = torch.load(restore_path, map_location="cpu")
         try:
             print(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
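Restoring now goes through `torch.load(..., map_location="cpu")`, so checkpoint tensors land in host memory before the state dict is applied. A small sketch of the pattern (the checkpoint layout with a "model" key follows the diff; the file name and extra fields are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# save a checkpoint shaped the way the trainer expects: {"model": state_dict, ...}
torch.save({"model": model.state_dict(), "step": 1000}, "checkpoint.pth")

# load onto the CPU first; this avoids allocating tensors on whichever GPU
# originally saved them (useful when restoring under DDP on a different device)
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])

# move to the target device only after the weights are loaded
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(" > Restored step:", checkpoint["step"])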
@@ -408,6 +413,19 @@ class Trainer:
                 batch[k] = to_cuda(v)
         return batch

+    @staticmethod
+    def master_params(optimizer: torch.optim.Optimizer):
+        """Generator over parameters owned by the optimizer.
+
+        Used to select parameters used by the optimizer for gradient clipping.
+
+        Args:
+            optimizer: Target optimizer.
+        """
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                yield p
+
     @staticmethod
     def _model_train_step(
         batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
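Item 4 of the commit message: gradient clipping now iterates only over the parameters registered with the optimizer being stepped, via the new `master_params` generator, instead of `model.parameters()`. A short sketch of why that matters when one model holds parameters for several optimizers (the generator/discriminator split and shapes are illustrative):

import torch
import torch.nn as nn

def master_params(optimizer: torch.optim.Optimizer):
    """Yield only the parameters owned by this optimizer (same idea as in the diff)."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

class TwoPartModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

model = TwoPartModel()
opt_g = torch.optim.Adam(model.gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(model.disc.parameters(), lr=1e-3)

# a discriminator-style loss still backpropagates through the generator,
# so model.parameters() would expose generator gradients to the clip as well
loss_d = model.disc(model.gen(torch.randn(4, 8))).mean()
loss_d.backward()

# clipping master_params(opt_d) rescales only the discriminator's own parameters
grad_norm = torch.nn.utils.clip_grad_norm_(master_params(opt_d), max_norm=1.0)
opt_d.step()
print("discriminator grad norm before clipping:", float(grad_norm))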
@@ -465,6 +483,8 @@ class Trainer:
         step_start_time = time.time()
+        # zero-out optimizer
         optimizer.zero_grad()

+        # forward pass and loss computation
         with torch.cuda.amp.autocast(enabled=config.mixed_precision):
             if optimizer_idx is not None:
                 outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@@ -476,9 +496,9 @@ class Trainer:
             step_time = time.time() - step_start_time
             return None, {}, step_time

-        # check nan loss
-        if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > NaN loss detected - {loss_dict}")
+        # # check nan loss
+        # if torch.isnan(loss_dict["loss"]).any():
+        #     raise RuntimeError(f" > NaN loss detected - {loss_dict}")

         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -496,6 +516,8 @@ class Trainer:
             update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:
+                # TODO: verify AMP use for GAN training in TTS
+                # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
                 with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                     scaled_loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -506,7 +528,7 @@ class Trainer:
                 scaler.scale(loss_dict["loss"]).backward()
                 if grad_clip > 0:
                     scaler.unscale_(optimizer)
-                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
                     # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
                     if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                         grad_norm = 0
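With mixed precision the gradients must be unscaled before their norm is measured and clipped, and a non-finite norm is tolerated because GradScaler will skip that optimizer step anyway. A minimal sketch of the ordering used above (model, loss, and the clip threshold are placeholders; `error_if_nonfinite` requires a recent PyTorch, as in the diff):

import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
grad_clip = 5.0

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=use_cuda):
    loss = model(torch.randn(4, 8, device=device)).pow(2).mean()

scaler.scale(loss).backward()          # gradients are still multiplied by the loss scale
if grad_clip > 0:
    scaler.unscale_(optimizer)         # bring gradients back to their true magnitude
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), grad_clip, error_if_nonfinite=False
    )
    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
        grad_norm = 0                  # report 0; GradScaler skips the step on inf/nan grads

scaler.step(optimizer)                 # step() knows unscale_ was already called
scaler.update()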
@@ -586,7 +608,8 @@ class Trainer:
             total_step_time = 0
             for idx, optimizer in enumerate(self.optimizer):
                 criterion = self.criterion
-                scaler = self.scaler[idx] if self.use_amp_scaler else None
+                # scaler = self.scaler[idx] if self.use_amp_scaler else None
+                scaler = self.scaler
                 scheduler = self.scheduler[idx]
                 outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@@ -677,9 +700,10 @@ class Trainer:
             if audios is not None:
                 self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)

-            self.dashboard_logger.flush()

         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
+        self.dashboard_logger.flush()
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -866,7 +890,7 @@ class Trainer:
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
-            self.c_logger.print_epoch_start(epoch, self.config.epochs)
+            self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -883,7 +907,8 @@ class Trainer:
         """Where the ✨️magic✨️ happens..."""
         try:
             self._fit()
-            self.dashboard_logger.finish()
+            if self.args.rank == 0:
+                self.dashboard_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
@@ -892,7 +917,8 @@ class Trainer:
             if self.num_gpus > 1:
                 dist.destroy_process_group()
             # finish the wandb run and sync data
-            self.dashboard_logger.finish()
+            if self.args.rank == 0:
+                self.dashboard_logger.finish()
             # stop without error signal
             try:
                 sys.exit(0)
@@ -942,18 +968,29 @@ class Trainer:
             keep_after=self.config.keep_after,
         )

-    @staticmethod
-    def _setup_logger_config(log_file: str) -> None:
-        handlers = [logging.StreamHandler()]
-
-        # Only add a log file if the output location is local due to poor
-        # support for writing logs to file-like objects.
-        parsed_url = urlparse(log_file)
-        if not parsed_url.scheme or parsed_url.scheme == "file":
-            schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
-            handlers.append(logging.FileHandler(schemeless_path))
-
-        logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
+    def _setup_logger_config(self, log_file: str) -> None:
+        """Write log strings to a file and print logs to the terminal.
+        TODO: Causes formatting issues in pdb debugging."""
+
+        class Logger(object):
+            def __init__(self, print_to_terminal=True):
+                self.print_to_terminal = print_to_terminal
+                self.terminal = sys.stdout
+                self.log = open(log_file, "a")
+
+            def write(self, message):
+                if self.print_to_terminal:
+                    self.terminal.write(message)
+                self.log.write(message)
+
+            def flush(self):
+                # this flush method is needed for python 3 compatibility.
+                # this handles the flush command by doing nothing.
+                # you might want to specify some extra behavior here.
+                pass
+
+        # don't let processes rank > 0 write to the terminal
+        sys.stdout = Logger(self.args.rank == 0)

     @staticmethod
     def _is_apex_available() -> bool:
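Item 2 of the commit message: every worker now tees its stdout into `trainer_<rank>_log.txt`, and only rank 0 keeps printing to the terminal. A standalone sketch of the same tee pattern outside the Trainer class (the class name, file name, and rank value are illustrative; unlike the diff's no-op, this flush forwards to both streams):

import sys

class StdoutTee:
    """Duplicate everything written to stdout into a log file, optionally muting the terminal."""

    def __init__(self, log_path, print_to_terminal=True):
        self.print_to_terminal = print_to_terminal
        self.terminal = sys.stdout
        self.log = open(log_path, "a", encoding="utf-8")

    def write(self, message):
        if self.print_to_terminal:
            self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # keep both streams consistent when something explicitly flushes stdout
        if self.print_to_terminal:
            self.terminal.flush()
        self.log.flush()

rank = 0  # would come from the --rank argument in DDP mode
sys.stdout = StdoutTee(f"trainer_{rank}_log.txt", print_to_terminal=(rank == 0))
print(" > this line goes to the terminal (rank 0) and to trainer_0_log.txt")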
@@ -1149,8 +1186,6 @@ def process_args(args, config=None):
         config = register_config(config_base.model)()
     # override values from command-line args
     config.parse_known_args(coqpit_overrides, relaxed_parser=True)
-    if config.mixed_precision:
-        print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
         experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
@@ -1170,8 +1205,7 @@ def process_args(args, config=None):
         used_characters = parse_symbols()
         new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)

-    dashboard_logger = init_logger(config)
+    dashboard_logger = init_dashboard_logger(config)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, dashboard_logger