# coqui-tts/TTS/bin/train_wavernn_vocoder.py
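"""Train a WaveRNN vocoder model with the coqui-tts training utilities.

Mel features are either loaded from CONFIG.feature_path or computed from the
wav files before training starts. Typical invocations (paths are illustrative):

    python TTS/bin/train_wavernn_vocoder.py --config_path config.json
    python TTS/bin/train_wavernn_vocoder.py --continue_path /path/to/run/

See the argument parser at the bottom of this file for all options.
"""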


import argparse
import math
import os
import pickle
import shutil
import sys
import traceback
import time
import glob
import random
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
    create_experiment_folder,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
    load_wav_data,
    find_feat_files,
    load_wav_feat_data,
    preprocess_wav_files,
)
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
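

# setup_torch_training_env() (from TTS.utils.training) prepares the torch training
# environment and returns CUDA availability and the number of visible GPUs.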
use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    if is_val and not CONFIG.run_eval:
        loader = None
    else:
        dataset = WaveRNNDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=CONFIG.seq_len,
            hop_len=ap.hop_length,
            pad=CONFIG.padding,
            mode=CONFIG.mode,
            is_training=not is_val,
            verbose=verbose,
        )
        # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            shuffle=True,
            collate_fn=dataset.collate,
            batch_size=CONFIG.batch_size,
            num_workers=CONFIG.num_val_loader_workers
            if is_val
            else CONFIG.num_loader_workers,
            pin_memory=True,
        )
    return loader
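

# format_data() below unpacks a WaveRNNDataset batch. Based on the dataset's collate
# function, x is assumed to hold the input waveform segments, m the matching
# mel-spectrogram conditioning frames, and y the target samples (class indices in
# discrete-output mode, float values for the "mold"/"gauss" modes).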
def format_data(data):
    # setup input data
    x = data[0]
    m = data[1]
    y = data[2]

    # dispatch data to GPU
    if use_cuda:
        x = x.cuda(non_blocking=True)
        m = m.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

    return x, m, y


def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # train loop
    print(" > Training", flush=True)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x, m, y = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        ##################
        # MODEL TRAINING #
        ##################
        y_hat = model(x, m)
        if isinstance(model.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y = y.float()
        y = y.unsqueeze(-1)
        # m_scaled, _ = model.upsample(m)

        # compute losses
        loss = criterion(y_hat, y)
        # loss.item() can never be None, so guard against a NaN loss instead
        if torch.isnan(loss):
            raise RuntimeError(" [!] NaN loss. Exiting ...")
        optimizer.zero_grad()
        loss.backward()
        if CONFIG.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip)
        optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # get the current learning rate
        cur_lr = list(optimizer.param_groups)[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        update_train_values = dict()
        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % CONFIG.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": cur_lr,
            }
            c_logger.print_train_step(
                batch_n_iter,
                num_iter,
                global_step,
                log_dict,
                loss_dict,
                keep_avg.avg_values,
            )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % CONFIG.save_step == 0:
            if CONFIG.checkpoint:
                # save model
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                )

            # synthesize a full voice
            wav_path = train_data[random.randrange(0, len(train_data))][0]
            wav = ap.load_wav(wav_path)
            ground_mel = ap.melspectrogram(wav)
            sample_wav = model.generate(
                ground_mel,
                CONFIG.batched,
                CONFIG.target_samples,
                CONFIG.overlap_samples,
            )
            predict_mel = ap.melspectrogram(sample_wav)

            # compute spectrograms
            figures = {
                "train/ground_truth": plot_spectrogram(ground_mel.T),
                "train/prediction": plot_spectrogram(predict_mel.T),
            }

            # Sample audio
            tb_logger.tb_train_audios(
                global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"]
            )
            tb_logger.tb_train_figures(global_step, figures)

        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    # create eval loader
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    with torch.no_grad():
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()
            # format data
            x, m, y = format_data(data)
            loader_time = time.time() - end_time
            global_step += 1

            y_hat = model(x, m)
            if isinstance(model.mode, int):
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y = y.float()
            y = y.unsqueeze(-1)
            loss = criterion(y_hat, y)
            # Compute avg loss
            # if num_gpus > 1:
            #     loss = reduce_tensor(loss.data, num_gpus)
            loss_dict = dict()
            loss_dict["model_loss"] = loss.item()

            step_time = time.time() - start_time
            epoch_time += step_time

            # update avg stats
            update_eval_values = dict()
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            update_eval_values["avg_loader_time"] = loader_time
            update_eval_values["avg_step_time"] = step_time
            keep_avg.update_values(update_eval_values)

            # print eval stats
            if CONFIG.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if epoch % CONFIG.test_every_epochs == 0:
        # synthesize a part of the eval data
        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav[:22000])
        sample_wav = model.generate(
            ground_mel,
            CONFIG.batched,
            CONFIG.target_samples,
            CONFIG.overlap_samples,
        )
        predict_mel = ap.melspectrogram(sample_wav)

        # compute spectrograms
        figures = {
            "eval/ground_truth": plot_spectrogram(ground_mel.T),
            "eval/prediction": plot_spectrogram(predict_mel.T),
        }

        # Sample audio
        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
        )
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**CONFIG.audio)

    print(f" > Loading wavs from: {CONFIG.data_path}")
    if CONFIG.feature_path is not None:
        print(f" > Loading features from: {CONFIG.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
        )
    else:
        mel_feat_path = os.path.join(OUT_PATH, "mel")
        feat_data = find_feat_files(mel_feat_path)
        if feat_data:
            print(f" > Loading features from: {mel_feat_path}")
            eval_data, train_data = load_wav_feat_data(
                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
            )
        else:
            print(" > No feature data found. Preprocessing...")
            # preprocess feature data from the given wav files
            preprocess_wav_files(OUT_PATH, CONFIG, ap)
            eval_data, train_data = load_wav_feat_data(
                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
            )

    # setup model
    model_wavernn = setup_wavernn(CONFIG)

    # define train functions
    if CONFIG.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif CONFIG.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(CONFIG.mode, int):
        criterion = torch.nn.CrossEntropyLoss()

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(CONFIG.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in CONFIG:
        scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
        scheduler = scheduler(optimizer, **CONFIG.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / CONFIG.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            # TODO: fix resetting restored optimizer lr
            # optimizer.load_state_dict(checkpoint["optimizer"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    if "best_loss" not in locals():
        best_loss = float("inf")

    global_step = args.restore_step
    for epoch in range(0, CONFIG.epochs):
        c_logger.print_epoch_start(epoch, CONFIG.epochs)
        _, global_step = train(
            model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch
        )
        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            model_losses=eval_avg_loss_dict,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue a previous run. If set, "--config_path" is ignored.',
        default="",
        required="--config_path" not in sys.argv,
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="",
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required="--continue_path" not in sys.argv,
    )
    parser.add_argument(
        "--debug",
        type=bool,  # NOTE: argparse treats any non-empty string as True here
        default=False,
        help="Do not verify commit integrity to run training.",
    )

    # DISTRIBUTED
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.",
    )
    parser.add_argument(
        "--group_id", type=str, default="", help="DISTRIBUTED: process group id."
    )
    args = parser.parse_args()

    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        # '*.pth.tar' matches every checkpoint in the run folder; pick the most recent one
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues from {args.restore_path}")

    # setup output paths and read configs
    CONFIG = load_config(args.config_path)
    # check_config(c)
    _ = os.path.dirname(os.path.realpath(__file__))

    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            CONFIG.output_path, CONFIG.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(
            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)