mirror of https://github.com/coqui-ai/TTS.git
scaler fix for wavegrad and wavernn. Save and load scaler
commit 5a59467f34
parent d8511efa8f
@@ -4,6 +4,7 @@ import os
 import sys
 import time
 import traceback
+import numpy as np

 import torch
 # DISTRIBUTED
@@ -94,14 +95,11 @@ def train(model, criterion, optimizer,
     c_logger.print_train_start()
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
-        model.module.compute_noise_level(noise_schedule['num_steps'],
-                                         noise_schedule['min_val'],
-                                         noise_schedule['max_val'])
+        model.module.compute_noise_level(betas)
     else:
-        model.compute_noise_level(noise_schedule['num_steps'],
-                                  noise_schedule['min_val'],
-                                  noise_schedule['max_val'])
+        model.compute_noise_level(betas)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
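The hunk above replaces the three-argument call with a precomputed beta array. For readers unfamiliar with WaveGrad, here is a minimal, self-contained sketch of what a compute_noise_level-style precomputation does with such a schedule; the helper name and the schedule values below are illustrative, not the repository's actual implementation or config.

import numpy as np

# Illustrative only: derive per-step noise levels from a linear beta schedule,
# as a WaveGrad-style compute_noise_level() typically does.
def compute_noise_level_sketch(betas):
    alphas = 1.0 - betas                 # signal kept at each diffusion step
    alphas_cumprod = np.cumprod(alphas)  # cumulative signal retention
    return np.sqrt(alphas_cumprod)       # noise level used to corrupt targets

# Example schedule values; the real ones come from c['train_noise_schedule'].
noise_schedule = {'min_val': 1e-6, 'max_val': 1e-2, 'num_steps': 1000}
betas = np.linspace(noise_schedule['min_val'],
                    noise_schedule['max_val'],
                    noise_schedule['num_steps'])
noise_levels = compute_noise_level_sketch(betas)
print(noise_levels.shape)  # (1000,)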
@@ -287,16 +285,13 @@ def evaluate(model, criterion, ap, global_step, epoch):

         # setup noise schedule and inference
         noise_schedule = c['test_noise_schedule']
+        betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
         if hasattr(model, 'module'):
-            model.module.compute_noise_level(noise_schedule['num_steps'],
-                                             noise_schedule['min_val'],
-                                             noise_schedule['max_val'])
+            model.module.compute_noise_level(betas)
             # compute voice
             x_pred = model.module.inference(m)
         else:
-            model.compute_noise_level(noise_schedule['num_steps'],
-                                      noise_schedule['min_val'],
-                                      noise_schedule['max_val'])
+            model.compute_noise_level(betas)
             # compute voice
             x_pred = model.inference(m)
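The evaluation path mirrors the training change but reads c['test_noise_schedule'], so inference can run on a different, usually much shorter, beta schedule than training. A short sketch with made-up values; the real ones live in the config.

import numpy as np

# Illustrative values only: a long schedule for training, a short one for
# inference; both are plain linear ramps handed to compute_noise_level().
train_betas = np.linspace(1e-6, 1e-2, 1000)  # e.g. training-time schedule
test_betas = np.linspace(1e-6, 1e-2, 50)     # e.g. faster inference schedule

Re-deriving the noise levels from test_betas right before inference(m) is what the hunk above does for both the DDP (model.module) and single-process cases.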
@@ -363,6 +358,9 @@ def main(args): # pylint: disable=redefined-outer-name
                 scheduler.load_state_dict(checkpoint['scheduler'])
                 # NOTE: Not sure if necessary
                 scheduler.optimizer = optimizer
+            if "scaler" in checkpoint and c.mixed_precision:
+                print(" > Restoring AMP Scaler...")
+                scaler.load_state_dict(checkpoint["scaler"])
         except RuntimeError:
             # retore only matching layers.
             print(" > Partial model initialization...")
@@ -81,7 +81,7 @@ def format_data(data):
     return x_input, mels, y_coarse


-def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
+def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
     # create train loader
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
@@ -94,7 +94,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
-    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
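The removed line constructed a fresh GradScaler on every call to train(), which resets the dynamic loss scale each epoch and leaves nothing stable to checkpoint; the commit instead builds one scaler in main() and passes it in. A minimal sketch of how such a shared scaler drives an AMP training step (illustrative model and data, not the WaveRNN loop).

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
mixed_precision = device == "cuda"  # stand-in for c.mixed_precision

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Created once, outside the epoch loop, so its state persists across epochs
# and can be saved with the checkpoint; a clean no-op when AMP is off.
scaler = torch.cuda.amp.GradScaler(enabled=mixed_precision)

for _ in range(3):  # stand-in for the epoch / batch loops
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=mixed_precision):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscales grads, then optimizer.step()
    scaler.update()                # adjust the loss scale for the next step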
@@ -192,6 +191,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
                     epoch,
                     OUT_PATH,
                     model_losses=loss_dict,
+                    scaler=scaler.state_dict() if c.mixed_precision else None
                 )

             # synthesize a full voice
@@ -352,6 +352,9 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup model
     model_wavernn = setup_wavernn(c)

+    # setup amp scaler
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # define train functions
     if c.mode == "mold":
         criterion = discretized_mix_logistic_loss
@@ -387,8 +390,9 @@ def main(args): # pylint: disable=redefined-outer-name
             print(" > Restoring Generator LR Scheduler...")
             scheduler.load_state_dict(checkpoint["scheduler"])
             scheduler.optimizer = optimizer
-            # TODO: fix resetting restored optimizer lr
-            # optimizer.load_state_dict(checkpoint["optimizer"])
+            if "scaler" in checkpoint and c.mixed_precision:
+                print(" > Restoring AMP Scaler...")
+                scaler.load_state_dict(checkpoint["scaler"])
         except RuntimeError:
             # retore only matching layers.
             print(" > Partial model initialization...")
@@ -416,7 +420,7 @@ def main(args): # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model_wavernn, optimizer,
-                               criterion, scheduler, ap, global_step, epoch)
+                               criterion, scheduler, scaler, ap, global_step, epoch)
         eval_avg_loss_dict = evaluate(
             model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -434,6 +438,7 @@ def main(args): # pylint: disable=redefined-outer-name
             epoch,
             OUT_PATH,
             model_losses=eval_avg_loss_dict,
+            scaler=scaler.state_dict() if c.mixed_precision else None
         )