scaler fix for wavegrad and wavernn. Save and load scaler

commit 5a59467f34
parent d8511efa8f
Author: erogol
Date:   2020-11-14 13:00:35 +01:00
2 changed files with 20 additions and 17 deletions

Changed file 1 of 2: the WaveGrad training script

@@ -4,6 +4,7 @@ import os
 import sys
 import time
 import traceback
+import numpy as np
 import torch
 # DISTRIBUTED
@@ -94,14 +95,11 @@ def train(model, criterion, optimizer,
     c_logger.print_train_start()
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
-        model.module.compute_noise_level(noise_schedule['num_steps'],
-                                         noise_schedule['min_val'],
-                                         noise_schedule['max_val'])
+        model.module.compute_noise_level(betas)
     else:
-        model.compute_noise_level(noise_schedule['num_steps'],
-                                  noise_schedule['min_val'],
-                                  noise_schedule['max_val'])
+        model.compute_noise_level(betas)

     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -287,16 +285,13 @@ def evaluate(model, criterion, ap, global_step, epoch):
         # setup noise schedule and inference
         noise_schedule = c['test_noise_schedule']
+        betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
         if hasattr(model, 'module'):
-            model.module.compute_noise_level(noise_schedule['num_steps'],
-                                             noise_schedule['min_val'],
-                                             noise_schedule['max_val'])
+            model.module.compute_noise_level(betas)

             # compute voice
             x_pred = model.module.inference(m)
         else:
-            model.compute_noise_level(noise_schedule['num_steps'],
-                                      noise_schedule['min_val'],
-                                      noise_schedule['max_val'])
+            model.compute_noise_level(betas)

             # compute voice
             x_pred = model.inference(m)
@@ -363,6 +358,9 @@ def main(args):  # pylint: disable=redefined-outer-name
             scheduler.load_state_dict(checkpoint['scheduler'])
             # NOTE: Not sure if necessary
             scheduler.optimizer = optimizer
+        if "scaler" in checkpoint and c.mixed_precision:
+            print(" > Restoring AMP Scaler...")
+            scaler.load_state_dict(checkpoint["scaler"])
     except RuntimeError:
         # restore only matching layers.
         print(" > Partial model initialization...")
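For context, the refactor in the two hunks above changes the schedule API: instead of passing (num_steps, min_val, max_val) and letting the model build the beta schedule internally, the caller now precomputes the beta array with np.linspace and hands it over, so any schedule shape can be injected without touching model code. Below is a minimal sketch of what such a betas-based noise-level computation typically looks like, assuming the standard WaveGrad formulation; compute_noise_level_from_betas is a hypothetical stand-in for the model method, and the schedule values are example numbers, not the repository's exact code or config.

import numpy as np

def compute_noise_level_from_betas(betas):
    # Illustrative only: standard WaveGrad-style derivation, not the repo's method.
    alphas = 1.0 - betas                   # per-step signal retention
    alphas_cumprod = np.cumprod(alphas)    # cumulative product over diffusion steps
    noise_level = np.sqrt(alphas_cumprod)  # sqrt(alpha_bar_t), scales the clean signal
    return noise_level

# Caller-side pattern matching the diff above (example schedule values):
noise_schedule = {'min_val': 1e-6, 'max_val': 1e-2, 'num_steps': 1000}
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
noise_level = compute_noise_level_from_betas(betas)

With this signature, a linear schedule for training and a short hand-tuned beta array for fast inference can share the same entry point, which is exactly why train and evaluate above build their betas from separate 'train_noise_schedule' and 'test_noise_schedule' configs.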

Changed file 2 of 2: the WaveRNN training script

@@ -81,7 +81,7 @@ def format_data(data):
     return x_input, mels, y_coarse


-def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
+def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
     # create train loader
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
@@ -94,7 +94,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
-    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
        start_time = time.time()
@@ -192,6 +191,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
                 epoch,
                 OUT_PATH,
                 model_losses=loss_dict,
+                scaler=scaler.state_dict() if c.mixed_precision else None
             )

         # synthesize a full voice
@@ -352,6 +352,9 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup model
     model_wavernn = setup_wavernn(c)

+    # setup amp scaler
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # define train functions
     if c.mode == "mold":
         criterion = discretized_mix_logistic_loss
@@ -387,8 +390,9 @@ def main(args):  # pylint: disable=redefined-outer-name
             print(" > Restoring Generator LR Scheduler...")
             scheduler.load_state_dict(checkpoint["scheduler"])
             scheduler.optimizer = optimizer
-        # TODO: fix resetting restored optimizer lr
-        # optimizer.load_state_dict(checkpoint["optimizer"])
+        if "scaler" in checkpoint and c.mixed_precision:
+            print(" > Restoring AMP Scaler...")
+            scaler.load_state_dict(checkpoint["scaler"])
     except RuntimeError:
         # restore only matching layers.
         print(" > Partial model initialization...")
@@ -416,7 +420,7 @@ def main(args):  # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model_wavernn, optimizer,
-                               criterion, scheduler, ap, global_step, epoch)
+                               criterion, scheduler, scaler, ap, global_step, epoch)
         eval_avg_loss_dict = evaluate(
             model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -434,6 +438,7 @@ def main(args):  # pylint: disable=redefined-outer-name
                 epoch,
                 OUT_PATH,
                 model_losses=eval_avg_loss_dict,
+                scaler=scaler.state_dict() if c.mixed_precision else None
             )
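Taken together, the WaveRNN changes implement the full AMP-scaler lifecycle: create the GradScaler once in main (only when mixed precision is on) instead of rebuilding it every epoch, thread it through train(), serialize its state into every checkpoint, and restore it on resume. Below is a condensed, self-contained sketch of that lifecycle using the real torch.cuda.amp API; it assumes a CUDA device, and the toy model, checkpoint path, and the slightly hardened .get()-based restore guard are illustrative choices, not the repository's exact code.

import torch
import torch.nn.functional as F

mixed_precision = True  # stands in for c.mixed_precision
model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters())

# create once, outside the epoch loop (the commit removes the per-epoch GradScaler())
scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

def train_step(x, y):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=mixed_precision):
        loss = F.mse_loss(model(x), y)
    if mixed_precision:
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
        scaler.update()                # adapts the loss-scale factor
    else:
        loss.backward()
        optimizer.step()

train_step(torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda())

# save: mirrors the diff's `scaler=scaler.state_dict() if c.mixed_precision else None`
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict() if mixed_precision else None,
}
torch.save(checkpoint, "checkpoint.pth")  # hypothetical path

# restore: the diff guards on key presence and the config flag; checking for None as
# well also covers a checkpoint written by a non-AMP run (a small added hardening)
checkpoint = torch.load("checkpoint.pth")
if checkpoint.get("scaler") is not None and mixed_precision:
    print(" > Restoring AMP Scaler...")
    scaler.load_state_dict(checkpoint["scaler"])

Restoring the scaler state matters because the loss scale adapts dynamically during training; resuming with a fresh scaler at the default initial scale of 2**16 can trigger a burst of skipped optimizer steps until it recalibrates, which is the instability this commit fixes.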