mirror of https://github.com/coqui-ai/TTS.git
Small trainer refactoring

1. Use a single GradScaler for all the optimizers.
2. Save terminal logs to a file. In DDP mode, each worker creates `trainer_N_log.txt`.
3. Allow only the main worker (rank == 0) to write to Tensorboard.
4. Pass only the parameters owned by the target optimizer to gradient clipping (`clip_grad_norm_`).
This commit is contained in:
parent 3ab8cef99e
commit 5911eec3b1
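Item 1 replaces the per-optimizer list of scalers with one shared `torch.cuda.amp.GradScaler`. The sketch below is a minimal, self-contained illustration of that pattern, not code from this repository (the two-optimizer model and losses are made up): one scaler can drive several optimizers in the same iteration as long as `update()` is called once at the end.

```python
import torch
import torch.nn as nn

# Illustrative two-optimizer setup (GAN-style); names are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
gen = nn.Linear(16, 16).to(device)
disc = nn.Linear(16, 1).to(device)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

# A single GradScaler serves both optimizers within one training step.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
x = torch.randn(8, 16, device=device)

# generator step
opt_g.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    g_loss = disc(gen(x)).mean()
scaler.scale(g_loss).backward()
scaler.step(opt_g)  # skipped internally if inf/NaN grads are found

# discriminator step
opt_d.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    d_loss = disc(gen(x).detach()).pow(2).mean()
scaler.scale(d_loss).backward()
scaler.step(opt_d)

# the loss scale is updated once, after every optimizer has stepped
scaler.update()
```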
@@ -43,7 +43,7 @@ def main():
         my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
         command[-1] = "--rank={}".format(i)
         # prevent stdout for processes with rank != 0
-        stdout = None if i == 0 else open(os.devnull, "w")
+        stdout = None
         p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) # pylint: disable=consider-using-with
         processes.append(p)
         print(command)
TTS/trainer.py (102 changed lines)
@@ -39,7 +39,7 @@ from TTS.utils.generic_utils import (
     to_cuda,
 )
 from TTS.utils.io import copy_model_files, load_fsspec, save_best_model, save_checkpoint
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_logger
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger, WandbLogger, init_dashboard_logger
 from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_available, setup_torch_training_env
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
@@ -157,29 +157,33 @@ class Trainer:
         - TPU training
         """

-        # set and initialize Pytorch runtime
-        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, dashboard_logger = process_args(args)

-        self.output_path = output_path
         self.args = args
         self.config = config
+        self.output_path = output_path
         self.config.output_log_path = output_path

+        # setup logging
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
+        self._setup_logger_config(log_file)
+
+        # set and initialize Pytorch runtime
+        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
+
         # init loggers
         self.c_logger = ConsoleLogger() if c_logger is None else c_logger
         self.dashboard_logger = dashboard_logger

-        if self.dashboard_logger is None:
-            self.dashboard_logger = init_logger(config)
+        # only allow dashboard logging for the main process in DDP mode
+        if self.dashboard_logger is None and args.rank == 0:
+            self.dashboard_logger = init_dashboard_logger(config)

         if not self.config.log_model_step:
             self.config.log_model_step = self.config.save_step

-        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
-        self._setup_logger_config(log_file)
-
         self.total_steps_done = 0
         self.epochs_done = 0
         self.restore_step = 0
@@ -253,9 +257,9 @@ class Trainer:
             if self.use_apex:
                 self.scaler = None
                 self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1")
-            if isinstance(self.optimizer, list):
-                self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
-            else:
-                self.scaler = torch.cuda.amp.GradScaler()
+            # if isinstance(self.optimizer, list):
+            #     self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer)
+            # else:
+            self.scaler = torch.cuda.amp.GradScaler()
         else:
             self.scaler = None
@@ -325,7 +329,8 @@ class Trainer:
             return obj

         print(" > Restoring from %s ..." % os.path.basename(restore_path))
-        checkpoint = load_fsspec(restore_path)
+        # checkpoint = load_fsspec(restore_path)
+        checkpoint = torch.load(restore_path, map_location="cpu")
         try:
             print(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
@@ -408,6 +413,19 @@ class Trainer:
                 batch[k] = to_cuda(v)
         return batch

+    @staticmethod
+    def master_params(optimizer: torch.optim.Optimizer):
+        """Generator over parameters owned by the optimizer.
+
+        Used to select parameters used by the optimizer for gradient clipping.
+
+        Args:
+            optimizer: Target optimizer.
+        """
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                yield p
+
     @staticmethod
     def _model_train_step(
         batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None
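For item 4, the point of the new `master_params` generator is that `torch.nn.utils.clip_grad_norm_` accepts any iterable of tensors, so feeding it only the parameters from one optimizer's `param_groups` clips that optimizer's parameters rather than the whole model's. A small sketch of the difference, with a made-up model and optimizer for illustration:

```python
import torch
import torch.nn as nn

def master_params(optimizer):
    # same idea as the helper added above: yield only the params this optimizer owns
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p

model = nn.ModuleDict({"gen": nn.Linear(4, 4), "disc": nn.Linear(4, 1)})
opt_g = torch.optim.SGD(model["gen"].parameters(), lr=0.1)

loss = model["disc"](model["gen"](torch.randn(2, 4))).sum()
loss.backward()

# clips only the generator's gradients ...
torch.nn.utils.clip_grad_norm_(master_params(opt_g), max_norm=1.0)
# ... whereas model.parameters() would also rescale the discriminator's gradients
```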
@@ -465,6 +483,8 @@ class Trainer:
         step_start_time = time.time()
         # zero-out optimizer
         optimizer.zero_grad()
+
+        # forward pass and loss computation
         with torch.cuda.amp.autocast(enabled=config.mixed_precision):
             if optimizer_idx is not None:
                 outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx)
@@ -476,9 +496,9 @@ class Trainer:
             step_time = time.time() - step_start_time
             return None, {}, step_time

-        # check nan loss
-        if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f" > NaN loss detected - {loss_dict}")
+        # # check nan loss
+        # if torch.isnan(loss_dict["loss"]).any():
+        #     raise RuntimeError(f" > NaN loss detected - {loss_dict}")

         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -496,6 +516,8 @@ class Trainer:
         update_lr_scheduler = True
         if self.use_amp_scaler:
             if self.use_apex:
+                # TODO: verify AMP use for GAN training in TTS
+                # https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
                 with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                     scaled_loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
@@ -506,7 +528,7 @@ class Trainer:
                 scaler.scale(loss_dict["loss"]).backward()
                 if grad_clip > 0:
                     scaler.unscale_(optimizer)
-                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
                     # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
                     if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                         grad_norm = 0
@@ -586,7 +608,8 @@ class Trainer:
             total_step_time = 0
             for idx, optimizer in enumerate(self.optimizer):
                 criterion = self.criterion
-                scaler = self.scaler[idx] if self.use_amp_scaler else None
+                # scaler = self.scaler[idx] if self.use_amp_scaler else None
+                scaler = self.scaler
                 scheduler = self.scheduler[idx]
                 outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
@@ -677,9 +700,10 @@ class Trainer:
             if audios is not None:
                 self.dashboard_logger.train_audios(self.total_steps_done, audios, self.ap.sample_rate)

+            self.dashboard_logger.flush()
+
         self.total_steps_done += 1
         self.callbacks.on_train_step_end()
-        self.dashboard_logger.flush()
         return outputs, loss_dict

     def train_epoch(self) -> None:
@@ -866,7 +890,7 @@ class Trainer:
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
-            self.c_logger.print_epoch_start(epoch, self.config.epochs)
+            self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -883,6 +907,7 @@ class Trainer:
         """Where the ✨️magic✨️ happens..."""
         try:
             self._fit()
-            self.dashboard_logger.finish()
+            if self.args.rank == 0:
+                self.dashboard_logger.finish()
         except KeyboardInterrupt:
             self.callbacks.on_keyboard_interrupt()
@@ -892,6 +917,7 @@ class Trainer:
             if self.num_gpus > 1:
                 dist.destroy_process_group()
             # finish the wandb run and sync data
-            self.dashboard_logger.finish()
+            if self.args.rank == 0:
+                self.dashboard_logger.finish()
             # stop without error signal
             try:
@@ -942,18 +968,29 @@ class Trainer:
                 keep_after=self.config.keep_after,
             )

-    @staticmethod
-    def _setup_logger_config(log_file: str) -> None:
-        handlers = [logging.StreamHandler()]
-
-        # Only add a log file if the output location is local due to poor
-        # support for writing logs to file-like objects.
-        parsed_url = urlparse(log_file)
-        if not parsed_url.scheme or parsed_url.scheme == "file":
-            schemeless_path = os.path.join(parsed_url.netloc, parsed_url.path)
-            handlers.append(logging.FileHandler(schemeless_path))
-
-        logging.basicConfig(level=logging.INFO, format="", handlers=handlers)
+    def _setup_logger_config(self, log_file: str) -> None:
+        """Write log strings to a file and print logs to the terminal.
+        TODO: Causes formatting issues in pdb debugging."""
+
+        class Logger(object):
+            def __init__(self, print_to_terminal=True):
+                self.print_to_terminal = print_to_terminal
+                self.terminal = sys.stdout
+                self.log = open(log_file, "a")
+
+            def write(self, message):
+                if self.print_to_terminal:
+                    self.terminal.write(message)
+                self.log.write(message)
+
+            def flush(self):
+                # this flush method is needed for python 3 compatibility.
+                # this handles the flush command by doing nothing.
+                # you might want to specify some extra behavior here.
+                pass
+
+        # don't let processes rank > 0 write to the terminal
+        sys.stdout = Logger(self.args.rank == 0)

     @staticmethod
     def _is_apex_available() -> bool:
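Item 2 is implemented by swapping `sys.stdout` for a tee object in every process. As a standalone sketch of the same pattern outside the `Trainer` (the `RANK` environment variable and file name below are illustrative stand-ins for what the trainer takes from `args.rank`):

```python
import os
import sys

class TeeLogger:
    """Write everything to a per-rank log file; echo to the terminal only for rank 0."""

    def __init__(self, log_file, print_to_terminal=True):
        self.print_to_terminal = print_to_terminal
        self.terminal = sys.stdout
        self.log = open(log_file, "a")

    def write(self, message):
        if self.print_to_terminal:
            self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.log.flush()

rank = int(os.environ.get("RANK", 0))  # illustrative; the trainer uses args.rank
sys.stdout = TeeLogger(f"trainer_{rank}_log.txt", print_to_terminal=(rank == 0))
print("reaches the terminal only on rank 0, but every rank's own log file")
```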
@@ -1149,8 +1186,6 @@ def process_args(args, config=None):
         config = register_config(config_base.model)()
     # override values from command-line args
     config.parse_known_args(coqpit_overrides, relaxed_parser=True)
-    if config.mixed_precision:
-        print(" > Mixed precision mode is ON")
     experiment_path = args.continue_path
     if not experiment_path:
         experiment_path = get_experiment_folder_path(config.output_path, config.run_name)
@@ -1170,8 +1205,7 @@ def process_args(args, config=None):
         used_characters = parse_symbols()
         new_fields["characters"] = used_characters
         copy_model_files(config, experiment_path, new_fields)
-
-    dashboard_logger = init_logger(config)
+    dashboard_logger = init_dashboard_logger(config)
     c_logger = ConsoleLogger()
     return config, experiment_path, audio_path, c_logger, dashboard_logger