mirror of https://github.com/coqui-ai/TTS.git
scaler fix for wavegrad and wavernn. Save and load scaler
commit 5a59467f34
parent d8511efa8f
@@ -4,6 +4,7 @@ import os
 import sys
 import time
 import traceback
+import numpy as np
 
 import torch
 # DISTRIBUTED
@@ -94,14 +95,11 @@ def train(model, criterion, optimizer,
     c_logger.print_train_start()
     # setup noise schedule
     noise_schedule = c['train_noise_schedule']
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
-        model.module.compute_noise_level(noise_schedule['num_steps'],
-                                         noise_schedule['min_val'],
-                                         noise_schedule['max_val'])
+        model.module.compute_noise_level(betas)
     else:
-        model.compute_noise_level(noise_schedule['num_steps'],
-                                  noise_schedule['min_val'],
-                                  noise_schedule['max_val'])
+        model.compute_noise_level(betas)
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
 
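Both branches now receive the same precomputed schedule instead of building it internally. A minimal sketch of the new call shape, with illustrative values standing in for c['train_noise_schedule']:

import numpy as np

# Illustrative values; the real ones come from c['train_noise_schedule'].
noise_schedule = {'min_val': 1e-6, 'max_val': 1e-2, 'num_steps': 1000}

# One noise level (beta) per diffusion step, evenly spaced.
betas = np.linspace(noise_schedule['min_val'],
                    noise_schedule['max_val'],
                    noise_schedule['num_steps'])

assert betas.shape == (noise_schedule['num_steps'],)
# model.compute_noise_level(betas)  # the model now consumes the array directly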
@@ -287,16 +285,13 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
     # setup noise schedule and inference
     noise_schedule = c['test_noise_schedule']
+    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], noise_schedule['num_steps'])
     if hasattr(model, 'module'):
-        model.module.compute_noise_level(noise_schedule['num_steps'],
-                                         noise_schedule['min_val'],
-                                         noise_schedule['max_val'])
+        model.module.compute_noise_level(betas)
         # compute voice
         x_pred = model.module.inference(m)
     else:
-        model.compute_noise_level(noise_schedule['num_steps'],
-                                  noise_schedule['min_val'],
-                                  noise_schedule['max_val'])
+        model.compute_noise_level(betas)
         # compute voice
         x_pred = model.inference(m)
 
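The hasattr(model, 'module') split is the usual way to reach the underlying network when it has been wrapped for multi-GPU training; a minimal sketch of the generic pattern, not specific to this repo:

import torch
from torch import nn

net = nn.Linear(4, 4)
wrapped = nn.DataParallel(net)  # DistributedDataParallel wraps the same way

# The wrapper exposes the original network as .module; plain models lack it.
core = wrapped.module if hasattr(wrapped, 'module') else wrapped
assert core is net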
@@ -363,6 +358,9 @@ def main(args): # pylint: disable=redefined-outer-name
             scheduler.load_state_dict(checkpoint['scheduler'])
             # NOTE: Not sure if necessary
             scheduler.optimizer = optimizer
+            if "scaler" in checkpoint and c.mixed_precision:
+                print(" > Restoring AMP Scaler...")
+                scaler.load_state_dict(checkpoint["scaler"])
         except RuntimeError:
             # retore only matching layers.
             print(" > Partial model initialization...")
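GradScaler keeps internal state (the current loss scale and its growth counters) that is worth checkpointing; restoring it avoids re-warming the scale from its default after a resume. A standalone sketch of the round trip (the dict here is illustrative, not the repo's full checkpoint format):

import torch

scaler = torch.cuda.amp.GradScaler()  # warns and disables itself without CUDA

# What save_checkpoint now stores alongside the model weights:
checkpoint = {'scaler': scaler.state_dict()}

# ...and what the restore path above does on resume:
resumed = torch.cuda.amp.GradScaler()
if 'scaler' in checkpoint:
    resumed.load_state_dict(checkpoint['scaler'])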
(file boundary: the hunks above patch the WaveGrad trainer; those below patch the WaveRNN trainer)

@@ -81,7 +81,7 @@ def format_data(data):
     return x_input, mels, y_coarse
 
 
-def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
+def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
     # create train loader
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
@@ -94,7 +94,6 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
     batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
     end_time = time.time()
     c_logger.print_train_start()
-    scaler = torch.cuda.amp.GradScaler()
     # train loop
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
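Previously a fresh GradScaler was created on every train() call, which reset the adaptive loss scale each epoch and left nothing to checkpoint; creating it once in main() and threading it through fixes both. A sketch of the standard torch.cuda.amp step such a scaler drives (the generic pattern, not the literal body of this train()):

import torch

def amp_step(model, optimizer, criterion, scaler, x, y):
    optimizer.zero_grad()
    # Run the forward pass in float16 where safe.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = criterion(model(x), y)
    if scaler is not None:
        scaler.scale(loss).backward()  # scale up to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
        scaler.update()                # adapt the scale for the next iteration
    else:
        loss.backward()
        optimizer.step()
    return loss.item()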
@@ -192,6 +191,7 @@ def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
                 epoch,
                 OUT_PATH,
                 model_losses=loss_dict,
+                scaler=scaler.state_dict() if c.mixed_precision else None
                 )
 
     # synthesize a full voice
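The conditional keeps checkpoints from non-AMP runs free of scaler state, which is what the "scaler" in checkpoint and c.mixed_precision guard on the restore path expects. A self-contained sketch of the kwarg's value (save_checkpoint is this repo's helper; the names below are illustrative):

import torch

mixed_precision = False  # stands in for c.mixed_precision
scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

# None when AMP is off, a real state dict when it is on.
payload = {'scaler': scaler.state_dict() if mixed_precision else None}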
@@ -352,6 +352,9 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup model
     model_wavernn = setup_wavernn(c)
 
+    # setup amp scaler
+    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
+
     # define train functions
     if c.mode == "mold":
         criterion = discretized_mix_logistic_loss
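Using None here means every call site must branch on the scaler's presence; an alternative idiom is a permanently disabled scaler, which keeps call sites branch-free. Both shown below as a sketch:

import torch

mixed_precision = True  # stands in for c.mixed_precision

# The commit's choice: no scaler object at all when AMP is off.
scaler = torch.cuda.amp.GradScaler() if mixed_precision else None

# Equivalent alternative: a disabled scaler whose scale()/step() become
# pass-throughs, so callers never need `if scaler is not None` checks.
scaler_alt = torch.cuda.amp.GradScaler(enabled=mixed_precision)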
@@ -387,8 +390,9 @@ def main(args): # pylint: disable=redefined-outer-name
             print(" > Restoring Generator LR Scheduler...")
             scheduler.load_state_dict(checkpoint["scheduler"])
             scheduler.optimizer = optimizer
-            # TODO: fix resetting restored optimizer lr
-            # optimizer.load_state_dict(checkpoint["optimizer"])
+            if "scaler" in checkpoint and c.mixed_precision:
+                print(" > Restoring AMP Scaler...")
+                scaler.load_state_dict(checkpoint["scaler"])
         except RuntimeError:
             # retore only matching layers.
             print(" > Partial model initialization...")
@@ -416,7 +420,7 @@ def main(args): # pylint: disable=redefined-outer-name
     for epoch in range(0, c.epochs):
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model_wavernn, optimizer,
-                               criterion, scheduler, ap, global_step, epoch)
+                               criterion, scheduler, scaler, ap, global_step, epoch)
         eval_avg_loss_dict = evaluate(
             model_wavernn, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -434,6 +438,7 @@ def main(args): # pylint: disable=redefined-outer-name
             epoch,
             OUT_PATH,
             model_losses=eval_avg_loss_dict,
+            scaler=scaler.state_dict() if c.mixed_precision else None
             )
 
 