Small trainer refactoring

1. Use a single GradScaler for all optimizers (see the sketch below)
2. Save terminal logs to a file. In DDP mode, each worker writes its own `trainer_N_log.txt`.
3. Allow only the main worker (rank==0) to write to Tensorboard
4. Pass only the parameters owned by the target optimizer to `clip_grad_norm_`
Eren Gölge 2021-08-26 17:08:58 +00:00
parent 3ab8cef99e
commit 5911eec3b1
2 changed files with 72 additions and 38 deletions
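For reference, here is a minimal, self-contained sketch of the pattern that points 1 and 4 describe (illustrative names and hyperparameters, not the Trainer code itself): one GradScaler instance is shared by every optimizer, and gradient clipping only touches the parameters owned by the optimizer currently being stepped.

```python
import torch


def master_params(optimizer: torch.optim.Optimizer):
    """Yield only the parameters owned by `optimizer`."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


device = "cuda" if torch.cuda.is_available() else "cpu"
# stand-ins for e.g. a GAN generator / discriminator pair
models = [torch.nn.Linear(8, 8).to(device), torch.nn.Linear(8, 8).to(device)]
optimizers = [torch.optim.Adam(m.parameters(), lr=1e-4) for m in models]
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))  # one scaler shared by all optimizers

for model, optimizer in zip(models, optimizers):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=(device == "cuda")):
        loss = model(torch.randn(2, 8, device=device)).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    # clip only this optimizer's parameters instead of model.parameters()
    torch.nn.utils.clip_grad_norm_(master_params(optimizer), max_norm=5.0, error_if_nonfinite=False)
    scaler.step(optimizer)
scaler.update()  # update the shared scaler once per training iteration
```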


@@ -43,7 +43,7 @@ def main():
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = "--rank={}".format(i)
# prevent stdout for processes with rank != 0
stdout = None if i == 0 else open(os.devnull, "w")
stdout = None
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
processes.append(p)
print(command)
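The change above stops redirecting stdout to `os.devnull` for non-zero ranks, since per-rank log files are now handled inside the Trainer (see `_setup_logger_config` further down). A simplified, hypothetical sketch of what the launcher does overall:

```python
# Simplified, hypothetical launcher: spawn one training process per GPU and pass its rank.
# Worker stdout is no longer silenced; each worker writes its own trainer_<rank>_log.txt.
import os
import subprocess

import torch


def spawn_workers(command):
    """`command` is the training script plus its arguments, ending with a --rank flag."""
    processes = []
    for i in range(torch.cuda.device_count()):
        env = os.environ.copy()
        env["PYTHON_EGG_CACHE"] = f"/tmp/tmp{i}"
        worker_command = command[:-1] + [f"--rank={i}"]
        # stdout=None means every rank inherits the parent's stdout
        processes.append(subprocess.Popen(["python3"] + worker_command, stdout=None, env=env))
    for p in processes:
        p.wait()
```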


@@ -39,7 +39,7 @@ from TTS.utils.generic_utils import (
to_cuda,
)
from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -157,29 +157,33 @@ class Trainer:
- TPU training
"""
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, dashboard_logger = process_args(args)
self.output_path = output_path
self.args = args
self.config = config
self.output_path = output_path
self.config.output_log_path = output_path
# setup logging
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# set and initialize Pytorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
# init loggers
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
self.dashboard_logger = dashboard_logger
if self.dashboard_logger is None:
self.dashboard_logger = init_logger(config)
# only allow dashboard logging for the main process in DDP mode
if self.dashboard_logger is None and args.rank == 0:
self.dashboard_logger = init_dashboard_logger(config)
if not self.config.log_model_step:
self.config.log_model_step = self.config.save_step
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
self.total_steps_done = 0
self.epochs_done = 0
self.restore_step = 0
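Condensed, the new initialization order amounts to the sketch below (simplified, assumed names; the real code uses `init_dashboard_logger` from `TTS.utils.logging`): every rank redirects its output to its own log file, and only rank 0 creates the dashboard logger, so all later dashboard calls have to be guarded by the rank as well.

```python
# Sketch only: per-rank log file plus rank-gated dashboard logger creation.
import os


def init_trainer_logging(rank: int, output_path: str, make_dashboard_logger):
    """`make_dashboard_logger` stands in for TTS.utils.logging.init_dashboard_logger."""
    log_file = os.path.join(output_path, f"trainer_{rank}_log.txt")
    # ... redirect this worker's prints into log_file (see _setup_logger_config below) ...
    dashboard_logger = make_dashboard_logger() if rank == 0 else None
    return log_file, dashboard_logger
```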
@@ -253,10 +257,10 @@ class Trainer:
if self.use_apex:
self.scaler = None
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
if isinstance(self.optimizer, list):
self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
else:
self.scaler = torch.cuda.amp.GradScaler()
# if isinstance(self.optimizer, list):
# self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
# else:
self.scaler = torch.cuda.amp.GradScaler()
else:
self.scaler = None
@@ -325,7 +329,8 @@ class Trainer:
return obj
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = load_fsspec(restore_path)
# checkpoint = load_fsspec(restore_path)
checkpoint = torch.load(restore_path, map_location="cpu")
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint["model"])
@@ -408,6 +413,19 @@ class Trainer:
batch[k] = to_cuda(v)
return batch
@staticmethod
def master_params(optimizer: torch.optim.Optimizer):
"""Generator over parameters owned by the optimizer.
Used to select parameters used by the optimizer for gradient clipping.
Args:
optimizer: Target optimizer.
"""
for group in optimizer.param_groups:
for p in group['params']:
yield p
@staticmethod
def _model_train_step(
batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
@@ -465,6 +483,8 @@ class Trainer:
step_start_time = time.time()
# zero-out optimizer
optimizer.zero_grad()
# forward pass and loss computation
with torch.cuda.amp.autocast(enabled=config.mixed_precision):
if optimizer_idx is not None:
outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@@ -476,9 +496,9 @@ class Trainer:
step_time = time.time() - step_start_time
return None, {}, step_time
# check nan loss
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f" > NaN loss detected - {loss_dict}")
# # check nan loss
# if torch.isnan(loss_dict["loss"]).any():
# raise RuntimeError(f" > NaN loss detected - {loss_dict}")
# set gradient clipping threshold
if "grad_clip" in config and config.grad_clip is not None:
@@ -496,6 +516,8 @@ class Trainer:
update_lr_scheduler = True
if self.use_amp_scaler:
if self.use_apex:
# TODO: verify AMP use for GAN training in TTS
# https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -506,7 +528,7 @@ class Trainer:
scaler.scale(loss_dict["loss"]).backward()
if grad_clip > 0:
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
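The non-finite check exists because an overflowed mixed-precision iteration produces inf/NaN gradients: `clip_grad_norm_(..., error_if_nonfinite=False)` then reports a non-finite norm while the GradScaler skips that optimizer step anyway, so the norm is zeroed purely to keep the logged value sane. A small illustrative sketch (assumed names):

```python
import torch

params = [torch.nn.Parameter(torch.randn(4))]
params[0].grad = torch.full((4,), float("inf"))  # simulate an AMP overflow

grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=5.0, error_if_nonfinite=False)
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
    grad_norm = 0  # keep the logged value finite; the scaler skips the step for this iteration
print(grad_norm)  # 0
```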
@@ -586,7 +608,8 @@ class Trainer:
total_step_time = 0
for idx, optimizer in enumerate(self.optimizer):
criterion = self.criterion
scaler = self.scaler[idx] if self.use_amp_scaler else None
# scaler = self.scaler[idx] if self.use_amp_scaler else None
scaler = self.scaler
scheduler = self.scheduler[idx]
outputs, loss_dict_new, step_time = self._optimize(
batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@@ -677,9 +700,10 @@ class Trainer:
if audios is not None:
self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)
self.dashboard_logger.flush()
self.total_steps_done += 1
self.callbacks.on_train_step_end()
self.dashboard_logger.flush()
return outputs, loss_dict
def train_epoch(self) -> None:
@@ -866,7 +890,7 @@ class Trainer:
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
self.epochs_done = epoch
self.c_logger.print_epoch_start(epoch, self.config.epochs)
self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
self.train_epoch()
if self.config.run_eval:
self.eval_epoch()
@@ -883,7 +907,8 @@ class Trainer:
"""Where the ✨magic✨ happens..."""
try:
self._fit()
self.dashboard_logger.finish()
if self.args.rank == 0:
self.dashboard_logger.finish()
except KeyboardInterrupt:
self.callbacks.on_keyboard_interrupt()
# if the output folder is empty remove the run.
@@ -892,7 +917,8 @@ class Trainer:
if self.num_gpus > 1:
dist.destroy_process_group()
# finish the wandb run and sync data
self.dashboard_logger.finish()
if self.args.rank == 0:
self.dashboard_logger.finish()
# stop without error signal
try:
sys.exit(0)
@@ -942,18 +968,29 @@ class Trainer:
keep_after=self.config.keep_after,
)
@staticmethod
def _setup_logger_config(log_file: str) -> None:
handlers = [logging.StreamHandler()]
def _setup_logger_config(self, log_file: str) -> None:
"""Write log strings to a file and print logs to the terminal.
TODO: Causes formatting issues in pdb debugging."""
# Only add a log file if the output location is local due to poor
# support for writing logs to file-like objects.
parsed_url = urlparse(log_file)
if not parsed_url.scheme or parsed_url.scheme == "file":
schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
handlers.append(logging.FileHandler(schemeless_path))
class Logger(object):
def __init__(self, print_to_terminal=True):
self.print_to_terminal = print_to_terminal
self.terminal = sys.stdout
self.log = open(log_file, "a")
logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
def write(self, message):
if self.print_to_terminal:
self.terminal.write(message)
self.log.write(message)
def flush(self):
# this flush method is needed for python 3 compatibility.
# this handles the flush command by doing nothing.
# you might want to specify some extra behavior here.
pass
# don't let processes rank > 0 write to the terminal
sys.stdout = Logger(self.args.rank == 0)
@staticmethod
def _is_apex_available() -> bool:
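As a usage sketch of the stdout redirection above (assumed class and file names; unlike the no-op `flush` in the diff, this one flushes the file): every `print()` is appended to the worker's `trainer_<rank>_log.txt`, and only rank 0 also echoes it to the real terminal.

```python
import sys


class TeeLogger:
    """Duplicate writes to a log file and, optionally, the original terminal."""

    def __init__(self, log_path: str, print_to_terminal: bool = True):
        self.print_to_terminal = print_to_terminal
        self.terminal = sys.stdout
        self.log = open(log_path, "a", encoding="utf-8")

    def write(self, message: str):
        if self.print_to_terminal:
            self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.log.flush()


rank = 1  # pretend we are a non-main DDP worker
sys.stdout = TeeLogger(f"trainer_{rank}_log.txt", print_to_terminal=(rank == 0))
print(" > This line goes only to trainer_1_log.txt")
sys.stdout = sys.stdout.terminal  # restore the real stdout afterwards
```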
@@ -1149,8 +1186,6 @@ def process_args(args, config=None):
config = register_config(config_base.model)()
# override values from command-line args
config.parse_known_args(coqpit_overrides, relaxed_parser=True)
if config.mixed_precision:
print(" > Mixed precision mode is ON")
experiment_path = args.continue_path
if not experiment_path:
experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
@@ -1170,8 +1205,7 @@ def process_args(args, config=None):
used_characters = parse_symbols()
new_fields["characters"] = used_characters
copy_model_files(config, experiment_path, new_fields)
dashboard_logger = init_logger(config)
dashboard_logger = init_dashboard_logger(config)
c_logger = ConsoleLogger()
return config, experiment_path, audio_path, c_logger, dashboard_logger