This commit is contained in:
Eren Gölge 2021-10-14 14:39:45 +00:00
parent 21cc0517a3
commit e15bc157d8
2 changed files with 4 additions and 10 deletions

View File

@ -626,17 +626,13 @@ class Trainer:
# https://nvidia.github.io/apex/advanced.html?highlight=accumulate#backward-passes-with-multiple-optimizers
with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
scaled_loss.backward()
- grad_norm = torch.nn.utils.clip_grad_norm_(
-     amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
- )
+ grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), grad_clip)
else:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict["loss"]).backward()
if grad_clip > 0:
scaler.unscale_(optimizer)
- grad_norm = torch.nn.utils.clip_grad_norm_(
-     self.master_params(optimizer), grad_clip, error_if_nonfinite=False
- )
+ grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0
@ -648,7 +644,7 @@ class Trainer:
# main model optimizer step
loss_dict["loss"].backward()
if grad_clip > 0:
- grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
step_time = time.time() - step_start_time

View File

@ -29,9 +29,7 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
- quant = (
-     ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
- )
+ quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
np.save(quant_path, quant)