mirror of https://github.com/coqui-ai/TTS.git
453 lines
16 KiB
Python
453 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
"""Trains WaveGrad vocoder models."""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
import traceback
|
|
import numpy as np
|
|
|
|
import torch
|
|
# DISTRIBUTED
|
|
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
|
from torch.optim import Adam
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from TTS.utils.arguments import parse_arguments, process_args
|
|
from TTS.utils.audio import AudioProcessor
|
|
from TTS.utils.distribute import init_distributed
|
|
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
|
remove_experiment_folder, set_init_dict)
|
|
from TTS.utils.training import setup_torch_training_env
|
|
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
|
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
|
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
|
|
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
|
|
|
# Configure the torch training environment once, at import time. The two
# results are module-level globals read by the loader/format/train helpers
# below: ``use_cuda`` gates every ``.cuda()`` transfer and ``num_gpus``
# selects between single-process and distributed data loading.
# NOTE(review): the two positional ``True`` flags presumably switch on cuDNN
# and cuDNN benchmark mode -- confirm against setup_torch_training_env.
use_cuda, num_gpus = setup_torch_training_env(True, True)
|
|
|
|
|
|
def setup_loader(ap, is_val=False, verbose=False):
    """Build the ``DataLoader`` for the training or the validation split.

    Args:
        ap: audio processor. Its ``hop_length`` is read here and the object
            itself is handed to ``WaveGradDataset``.
        is_val: when true, use ``eval_data`` and the validation worker count;
            otherwise use ``train_data`` and the training worker count.
        verbose: forwarded unchanged to ``WaveGradDataset``.

    Returns:
        A ``DataLoader`` over a ``WaveGradDataset``, or ``None`` when a
        validation loader is requested while ``c.run_eval`` is falsy.
    """
    # Validation is disabled in the config: there is nothing to load.
    if is_val and not c.run_eval:
        return None

    dataset = WaveGradDataset(
        ap=ap,
        items=eval_data if is_val else train_data,
        seq_len=c.seq_len,
        hop_len=ap.hop_length,
        pad_short=c.pad_short,
        conv_pad=c.conv_pad,
        is_training=not is_val,
        return_segments=True,
        use_noise_augment=False,
        use_cache=c.use_cache,
        verbose=verbose,
    )

    # With several GPUs the DistributedSampler owns the ordering, so the
    # loader's own shuffling has to be switched off in that case.
    multi_gpu = num_gpus > 1
    sampler = DistributedSampler(dataset) if multi_gpu else None
    worker_count = (c.num_val_loader_workers
                    if is_val else c.num_loader_workers)

    return DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=not multi_gpu,
        drop_last=False,
        sampler=sampler,
        num_workers=worker_count,
        pin_memory=False,
    )
|
|
|
|
|
|
def format_data(data):
    """Unpack a loader batch and prepare it for the model.

    Args:
        data: two-item sequence ``(m, x)`` yielded by the data loader.
            Presumably ``m`` is the conditioning feature batch and ``x`` the
            matching waveform segments -- confirm against WaveGradDataset.

    Returns:
        ``(m, x)`` where ``x`` has gained a new axis at dim 1. Both tensors
        are moved to the CUDA device when the module-level ``use_cuda`` flag
        is set.
    """
    m, x = data
    # insert a singleton axis right after the first dimension
    x = x.unsqueeze(1)
    if not use_cuda:
        return m, x
    return m.cuda(non_blocking=True), x.cuda(non_blocking=True)
|
|
|
|
|
|
def format_test_data(data):
    """Turn a single un-batched test sample into a batch of one.

    Args:
        data: two-item sequence ``(m, x)`` for one sample, as returned by the
            dataset's ``load_test_samples``.

    Returns:
        ``(m, x)`` where ``m`` has one leading axis prepended and ``x`` has
        two leading axes prepended. Both tensors are moved to the CUDA device
        when the module-level ``use_cuda`` flag is set.
    """
    m, x = data
    m = m[None, ...]          # prepend one axis
    x = x[None, None, ...]    # prepend two axes
    if not use_cuda:
        return m, x
    return m.cuda(non_blocking=True), x.cuda(non_blocking=True)
|
|
|
|
|
|
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
          epoch):
    """Run one training epoch.

    Args:
        model: the generator, either bare or wrapped so that the underlying
            network is reachable through ``model.module``.
        criterion: callable computing the loss between the sampled noise and
            the noise predicted by the model.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per batch, or
            ``None`` to keep the learning rate untouched.
        scaler: AMP gradient scaler; only used when ``c.mixed_precision``.
        ap: audio processor, forwarded to ``setup_loader``.
        global_step: number of training steps taken before this epoch.
        epoch: index of the current epoch.

    Returns:
        Tuple ``(keep_avg.avg_values, global_step)`` holding the running
        averages of this epoch and the updated step counter.

    Raises:
        RuntimeError: if the loss becomes NaN.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))

    # A brand-new DistributedSampler is created for every epoch and always
    # starts from its internal epoch 0. Without telling it the real epoch it
    # would replay exactly the same shuffled order each time.
    sampler = getattr(data_loader, 'sampler', None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)

    model.train()
    # The noise helpers live on the underlying network, not on the wrapper.
    net = model.module if hasattr(model, 'module') else model

    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()

    # setup noise schedule
    noise_schedule = c['train_noise_schedule']
    betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
                        noise_schedule['num_steps'])
    net.compute_noise_level(betas)

    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # compute noisy input
            noise, x_noisy, noise_scale = net.compute_y_n(x)

            # forward pass (through the wrapper, when there is one)
            noise_hat = model(x_noisy, m, noise_scale)

            # compute losses
            loss = criterion(noise, noise_hat)
            loss_wavegrad_dict = {'wavegrad_loss': loss}

        # check nan loss
        if torch.isnan(loss).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')

        optimizer.zero_grad()

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss).backward()
            # gradients must be unscaled before their norm is clipped
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.clip_grad)
            optimizer.step()

        # schedule update
        if scheduler is not None:
            scheduler.step()

        # disconnect loss values (plain numbers pass through unchanged, the
        # same rule evaluate() applies)
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        # epoch/step timing
        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        # update avg stats
        update_train_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                'step_time': [step_time, 2],
                'loader_time': [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item()
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm.item(),
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0 and c.checkpoint:
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None
                )

        # restart the clock so the next loader_time covers only data loading
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    """Run one evaluation pass and, on rank 0, log a synthesized sample.

    Args:
        model: the generator, either bare or wrapped so that the underlying
            network is reachable through ``model.module``.
        criterion: callable computing the loss between the sampled noise and
            the noise predicted by the model.
        ap: audio processor, forwarded to ``setup_loader`` and
            ``plot_results``.
        global_step: training step count at the time evaluation starts.
        epoch: index of the current epoch.

    Returns:
        ``keep_avg.avg_values``, the running averages collected over the
        evaluation batches.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    # The noise and inference helpers live on the underlying network.
    net = model.module if hasattr(model, 'module') else model

    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        # NOTE(review): this advances only the local copy of the counter, so
        # the rank-0 logging below is written at the training step plus the
        # number of evaluation batches -- confirm that offset is intended.
        global_step += 1

        # compute noisy input
        noise, x_noisy, noise_scale = net.compute_y_n(x)

        # forward pass (through the wrapper, when there is one)
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {'wavegrad_loss': loss}

        # detach loss values; plain numbers pass through unchanged
        loss_dict = {
            key: value if isinstance(value, (int, float)) else value.item()
            for key, value in loss_wavegrad_dict.items()
        }

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {
            'avg_' + key: value for key, value in loss_dict.items()
        }
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # Restart the clock, as train() does. Without this every loader_time
        # would measure the time since evaluation began instead of the time
        # spent waiting for the current batch.
        end_time = time.time()

    if args.rank == 0:
        # temporarily switch the dataset away from segment mode to fetch a
        # test sample; the flag is restored at the end of this branch
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup noise schedule and inference
        noise_schedule = c['test_noise_schedule']
        betas = np.linspace(noise_schedule['min_val'],
                            noise_schedule['max_val'],
                            noise_schedule['num_steps'])
        net.compute_noise_level(betas)
        # compute voice
        x_pred = net.inference(m)

        # compute spectrograms
        figures = plot_results(x_pred, x, ap, global_step, 'eval')
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                                 c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        data_loader.dataset.return_segments = True

    return keep_avg.avg_values
|
|
|
|
|
|
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, model and optimizer, then run the train/eval loop.

    Args:
        args: parsed command-line arguments. ``rank``, ``group_id``,
            ``restore_path`` and ``best_path`` are read; ``restore_step`` is
            written.

    Side effects:
        Assigns the module-level ``train_data`` and ``eval_data`` lists that
        ``setup_loader`` reads, and saves the best model after each epoch via
        ``save_best_model``.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
                                                   c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if 'lr_scheduler' in c:
        scheduler_class = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler_class(optimizer, **c.lr_scheduler_params)

    # setup criterion; it is moved to the GPU together with the model below
    criterion = torch.nn.L1Loss()

    if use_cuda:
        model.cuda()
        criterion.cuda()

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restore the scheduler only when this run has one and the
            # checkpoint actually carries a state for it. Otherwise the call
            # would fail on ``None`` with an error the handler below does not
            # catch.
            if (scheduler is not None
                    and checkpoint.get('scheduler') is not None):
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint['scheduler'])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            # The scaler entry is saved as None by runs without mixed
            # precision, so a missing and an empty entry are both skipped.
            if c.mixed_precision and checkpoint.get("scaler") is not None:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr to the configured value after restoring.
        for group in optimizer.param_groups:
            group['lr'] = c.lr

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                 scheduler, scaler, ap,
                                                 global_step, epoch)
        if c.run_eval:
            epoch_loss_dict = evaluate(model, criterion, ap, global_step,
                                       epoch)
            c_logger.print_epoch_end(epoch, epoch_loss_dict)
        else:
            # setup_loader() hands back no validation loader in this mode, so
            # evaluate() cannot run; rank the epoch by the training averages,
            # which use the same 'avg_*' key names.
            epoch_loss_dict = train_avg_loss_dict
        target_loss = epoch_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=epoch_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None
        )
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse the command line and build the experiment state. Everything
    # assigned here is a module-level global that the functions above read
    # directly (config ``c``, output path and the two loggers).
    args = parse_arguments(sys.argv)
    (c, OUT_PATH, AUDIO_PATH, c_logger,
     tb_logger) = process_args(args, model_class='vocoder')

    try:
        main(args)
    except KeyboardInterrupt:
        # Interrupted by the user: clean up the experiment folder, then
        # terminate with a success status.
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            # hard exit, bypassing normal interpreter shutdown
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        # Any other failure: clean up, show the traceback and exit non-zero.
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
|