mirror of https://github.com/coqui-ai/TTS.git
Implement unified trainer
This commit is contained in:
parent 6d7b5fbcde
commit c7aad884cd
@@ -13,8 +13,8 @@ from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
 from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
 from TTS.speaker_encoder.utils.visual import plot_embeddings
+from TTS.trainer import init_training
 from TTS.tts.datasets import load_meta_data
-from TTS.utils.arguments import init_training
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
 from TTS.utils.radam import RAdam
@@ -1,27 +1,13 @@
-import os
 import sys
-import traceback
 
-from TTS.tts.trainer_tts import TrainerTTS
-from TTS.utils.arguments import init_training
-from TTS.utils.generic_utils import remove_experiment_folder
+from TTS.trainer import Trainer, init_training
 
 
 def main():
-    try:
-        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
-        trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path)
-        trainer.fit()
-    except KeyboardInterrupt:
-        remove_experiment_folder(output_path)
-        try:
-            sys.exit(0)
-        except SystemExit:
-            os._exit(0)  # pylint: disable=protected-access
-    except Exception:  # pylint: disable=broad-except
-        remove_experiment_folder(output_path)
-        traceback.print_exc()
-        sys.exit(1)
+    """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```"""
+    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
+    trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
+    trainer.fit()
 
 
 if __name__ == "__main__":
@@ -0,0 +1,27 @@
+import os
+import sys
+import traceback
+
+from TTS.trainer import Trainer, init_training
+from TTS.utils.generic_utils import remove_experiment_folder
+
+
+def main():
+    try:
+        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
+        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
+        trainer.fit()
+    except KeyboardInterrupt:
+        remove_experiment_folder(output_path)
+        try:
+            sys.exit(0)
+        except SystemExit:
+            os._exit(0)  # pylint: disable=protected-access
+    except Exception:  # pylint: disable=broad-except
+        remove_experiment_folder(output_path)
+        traceback.print_exc()
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
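Both entry points above funnel everything through the same pair of calls: init_training(sys.argv) builds the args, config, output folder, and loggers from the command line, and Trainer(...).fit() runs the whole training loop. The interrupt handling around fit() is the one piece this vocoder wrapper keeps. A minimal, self-contained sketch of that pattern, with stand-ins for trainer.fit() and the TTS cleanup helper (the output path and behavior of the stand-ins are illustrative assumptions, not the library's code):

import os
import sys
import traceback


def remove_experiment_folder(path):
    # stand-in for TTS.utils.generic_utils.remove_experiment_folder
    print(f" > Removing {path}")


def main():
    output_path = "/tmp/run"  # hypothetical experiment folder
    try:
        raise KeyboardInterrupt  # stand-in for a Ctrl-C during trainer.fit()
    except KeyboardInterrupt:
        remove_experiment_folder(output_path)
        try:
            sys.exit(0)
        except SystemExit:
            # os._exit skips cleanup handlers, presumably so that forked
            # DataLoader workers cannot keep the process alive
            os._exit(0)
    except Exception:  # any training failure: clean up, report, fail loudly
        remove_experiment_folder(output_path)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()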
@@ -1,638 +0,0 @@
-#!/usr/bin/env python3
-# TODO: mixed precision training
-"""Trains GAN based vocoder model."""
-
-import itertools
-import os
-import sys
-import time
-import traceback
-from inspect import signature
-
-import torch
-
-# DISTRIBUTED
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from TTS.utils.arguments import init_training
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import init_distributed
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.training import setup_torch_training_env
-from TTS.vocoder.datasets.gan_dataset import GANDataset
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
-from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator
-from TTS.vocoder.utils.io import save_best_model, save_checkpoint
-
-use_cuda, num_gpus = setup_torch_training_env(True, True)
-
-
-def setup_loader(ap, is_val=False, verbose=False):
-    loader = None
-    if not is_val or c.run_eval:
-        dataset = GANDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=c.seq_len,
-            hop_len=ap.hop_length,
-            pad_short=c.pad_short,
-            conv_pad=c.conv_pad,
-            return_pairs=c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False,
-            is_training=not is_val,
-            return_segments=not is_val,
-            use_noise_augment=c.use_noise_augment,
-            use_cache=c.use_cache,
-            verbose=verbose,
-        )
-        dataset.shuffle_mapping()
-        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            batch_size=1 if is_val else c.batch_size,
-            shuffle=num_gpus == 0,
-            drop_last=False,
-            sampler=sampler,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
-            pin_memory=False,
-        )
-    return loader
-
-
-def format_data(data):
-    if isinstance(data[0], list):
-        x_G, y_G = data[0]
-        x_D, y_D = data[1]
-        if use_cuda:
-            x_G = x_G.cuda(non_blocking=True)
-            y_G = y_G.cuda(non_blocking=True)
-            x_D = x_D.cuda(non_blocking=True)
-            y_D = y_D.cuda(non_blocking=True)
-        return x_G, y_G, x_D, y_D
-    x, y = data
-    if use_cuda:
-        x = x.cuda(non_blocking=True)
-        y = y.cuda(non_blocking=True)
-    return x, y, None, None
-
-
-def train(
-    model_G,
-    criterion_G,
-    optimizer_G,
-    model_D,
-    criterion_D,
-    optimizer_D,
-    scheduler_G,
-    scheduler_D,
-    ap,
-    global_step,
-    epoch,
-):
-    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
-    model_G.train()
-    model_D.train()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
-    else:
-        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
-    end_time = time.time()
-    c_logger.print_train_start()
-    for num_iter, data in enumerate(data_loader):
-        start_time = time.time()
-
-        # format data
-        c_G, y_G, c_D, y_D = format_data(data)
-        loader_time = time.time() - end_time
-
-        global_step += 1
-
-        ##############################
-        # GENERATOR
-        ##############################
-
-        # generator pass
-        y_hat = model_G(c_G)
-        y_hat_sub = None
-        y_G_sub = None
-        y_hat_vis = y_hat  # for visualization
-
-        # PQMF formatting
-        if y_hat.shape[1] > 1:
-            y_hat_sub = y_hat
-            y_hat = model_G.pqmf_synthesis(y_hat)
-            y_hat_vis = y_hat
-            y_G_sub = model_G.pqmf_analysis(y_G)
-
-        scores_fake, feats_fake, feats_real = None, None, None
-        if global_step > c.steps_to_start_discriminator:
-
-            # run D with or without cond. features
-            if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(y_hat, c_G)
-            else:
-                D_out_fake = model_D(y_hat)
-            D_out_real = None
-
-            if c.use_feat_match_loss:
-                with torch.no_grad():
-                    D_out_real = model_D(y_G)
-
-            # format D outputs
-            if isinstance(D_out_fake, tuple):
-                scores_fake, feats_fake = D_out_fake
-                if D_out_real is None:
-                    feats_real = None
-                else:
-                    # we don't need scores for real samples for training G since they are always 1
-                    _, feats_real = D_out_real
-            else:
-                scores_fake = D_out_fake
-
-        # compute losses
-        loss_G_dict = criterion_G(
-            y_hat=y_hat,
-            y=y_G,
-            scores_fake=scores_fake,
-            feats_fake=feats_fake,
-            feats_real=feats_real,
-            y_hat_sub=y_hat_sub,
-            y_sub=y_G_sub,
-        )
-        loss_G = loss_G_dict["G_loss"]
-
-        # optimizer generator
-        optimizer_G.zero_grad()
-        loss_G.backward()
-        if c.gen_clip_grad > 0:
-            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
-        optimizer_G.step()
-
-        loss_dict = dict()
-        for key, value in loss_G_dict.items():
-            if isinstance(value, int):
-                loss_dict[key] = value
-            else:
-                loss_dict[key] = value.item()
-
-        ##############################
-        # DISCRIMINATOR
-        ##############################
-        if global_step >= c.steps_to_start_discriminator:
-            # discriminator pass
-            if c.diff_samples_for_G_and_D:
-                # use a different sample than generator
-                with torch.no_grad():
-                    y_hat = model_G(c_D)
-
-                # PQMF formatting
-                if y_hat.shape[1] > 1:
-                    y_hat = model_G.pqmf_synthesis(y_hat)
-            else:
-                # use the same samples as generator
-                c_D = c_G.clone()
-                y_D = y_G.clone()
-
-            # run D with or without cond. features
-            if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(y_hat.detach().clone(), c_D)
-                D_out_real = model_D(y_D, c_D)
-            else:
-                D_out_fake = model_D(y_hat.detach())
-                D_out_real = model_D(y_D)
-
-            # format D outputs
-            if isinstance(D_out_fake, tuple):
-                # model_D returns scores and features
-                scores_fake, feats_fake = D_out_fake
-                if D_out_real is None:
-                    scores_real, feats_real = None, None
-                else:
-                    scores_real, feats_real = D_out_real
-            else:
-                # model D returns only scores
-                scores_fake = D_out_fake
-                scores_real = D_out_real
-
-            # compute losses
-            loss_D_dict = criterion_D(scores_fake, scores_real)
-            loss_D = loss_D_dict["D_loss"]
-
-            # optimizer discriminator
-            optimizer_D.zero_grad()
-            loss_D.backward()
-            if c.disc_clip_grad > 0:
-                torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
-            optimizer_D.step()
-
-            for key, value in loss_D_dict.items():
-                if isinstance(value, (int, float)):
-                    loss_dict[key] = value
-                else:
-                    loss_dict[key] = value.item()
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # get current learning rates
-        current_lr_G = list(optimizer_G.param_groups)[0]["lr"]
-        current_lr_D = list(optimizer_D.param_groups)[0]["lr"]
-
-        # update avg stats
-        update_train_values = dict()
-        for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_train_values)
-
-        # print training stats
-        if global_step % c.print_step == 0:
-            log_dict = {
-                "step_time": [step_time, 2],
-                "loader_time": [loader_time, 4],
-                "current_lr_G": current_lr_G,
-                "current_lr_D": current_lr_D,
-            }
-            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
-
-        if args.rank == 0:
-            # plot step stats
-            if global_step % 10 == 0:
-                iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
-                iter_stats.update(loss_dict)
-                tb_logger.tb_train_step_stats(global_step, iter_stats)
-
-            # save checkpoint
-            if global_step % c.save_step == 0:
-                if c.checkpoint:
-                    # save model
-                    save_checkpoint(
-                        model_G,
-                        optimizer_G,
-                        scheduler_G,
-                        model_D,
-                        optimizer_D,
-                        scheduler_D,
-                        global_step,
-                        epoch,
-                        OUT_PATH,
-                        model_losses=loss_dict,
-                    )
-
-                # compute spectrograms
-                figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
-                tb_logger.tb_train_figures(global_step, figures)
-
-                # Sample audio
-                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
-                tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"])
-        end_time = time.time()
-
-    if scheduler_G is not None:
-        scheduler_G.step()
-
-    if scheduler_D is not None:
-        scheduler_D.step()
-
-    # print epoch stats
-    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
-
-    # Plot Training Epoch Stats
-    epoch_stats = {"epoch_time": epoch_time}
-    epoch_stats.update(keep_avg.avg_values)
-    if args.rank == 0:
-        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
-    # TODO: plot model stats
-    # if c.tb_model_param_stats:
-    #     tb_logger.tb_model_weights(model, global_step)
-    torch.cuda.empty_cache()
-    return keep_avg.avg_values, global_step
-
-
-@torch.no_grad()
-def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
-    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
-    model_G.eval()
-    model_D.eval()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    end_time = time.time()
-    c_logger.print_eval_start()
-    for num_iter, data in enumerate(data_loader):
-        start_time = time.time()
-
-        # format data
-        c_G, y_G, _, _ = format_data(data)
-        loader_time = time.time() - end_time
-
-        global_step += 1
-
-        ##############################
-        # GENERATOR
-        ##############################
-
-        # generator pass
-        y_hat = model_G(c_G)[:, :, : y_G.size(2)]
-        y_hat_sub = None
-        y_G_sub = None
-
-        # PQMF formatting
-        if y_hat.shape[1] > 1:
-            y_hat_sub = y_hat
-            y_hat = model_G.pqmf_synthesis(y_hat)
-            y_G_sub = model_G.pqmf_analysis(y_G)
-
-        scores_fake, feats_fake, feats_real = None, None, None
-        if global_step > c.steps_to_start_discriminator:
-
-            if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(y_hat, c_G)
-            else:
-                D_out_fake = model_D(y_hat)
-            D_out_real = None
-
-            if c.use_feat_match_loss:
-                with torch.no_grad():
-                    D_out_real = model_D(y_G)
-
-            # format D outputs
-            if isinstance(D_out_fake, tuple):
-                scores_fake, feats_fake = D_out_fake
-                if D_out_real is None:
-                    feats_real = None
-                else:
-                    _, feats_real = D_out_real
-            else:
-                scores_fake = D_out_fake
-                feats_fake, feats_real = None, None
-
-        # compute losses
-        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)
-
-        loss_dict = dict()
-        for key, value in loss_G_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict[key] = value
-            else:
-                loss_dict[key] = value.item()
-
-        ##############################
-        # DISCRIMINATOR
-        ##############################
-
-        if global_step >= c.steps_to_start_discriminator:
-            # discriminator pass
-            with torch.no_grad():
-                y_hat = model_G(c_G)[:, :, : y_G.size(2)]
-
-            # PQMF formatting
-            if y_hat.shape[1] > 1:
-                y_hat = model_G.pqmf_synthesis(y_hat)
-
-            # run D with or without cond. features
-            if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(y_hat.detach(), c_G)
-                D_out_real = model_D(y_G, c_G)
-            else:
-                D_out_fake = model_D(y_hat.detach())
-                D_out_real = model_D(y_G)
-
-            # format D outputs
-            if isinstance(D_out_fake, tuple):
-                scores_fake, feats_fake = D_out_fake
-                if D_out_real is None:
-                    scores_real, feats_real = None, None
-                else:
-                    scores_real, feats_real = D_out_real
-            else:
-                scores_fake = D_out_fake
-                scores_real = D_out_real
-
-            # compute losses
-            loss_D_dict = criterion_D(scores_fake, scores_real)
-
-            for key, value in loss_D_dict.items():
-                if isinstance(value, (int, float)):
-                    loss_dict[key] = value
-                else:
-                    loss_dict[key] = value.item()
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # update avg stats
-        update_eval_values = dict()
-        for key, value in loss_dict.items():
-            update_eval_values["avg_" + key] = value
-        update_eval_values["avg_loader_time"] = loader_time
-        update_eval_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_eval_values)
-
-        # print eval stats
-        if c.print_eval:
-            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
-
-    if args.rank == 0:
-        # compute spectrograms
-        figures = plot_results(y_hat, y_G, ap, global_step, "eval")
-        tb_logger.tb_eval_figures(global_step, figures)
-
-        # Sample audio
-        predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
-        real_waveform = y_G[0].squeeze(0).cpu().numpy()
-        tb_logger.tb_eval_audios(
global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
|
||||
-        )
-
-        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
-
-    # synthesize a full voice
-    data_loader.return_segments = False
-    torch.cuda.empty_cache()
-    return keep_avg.avg_values
-
-
-def main(args):  # pylint: disable=redefined-outer-name
-    # pylint: disable=global-variable-undefined
-    global train_data, eval_data
-    print(f" > Loading wavs from: {c.data_path}")
-    if c.feature_path is not None:
-        print(f" > Loading features from: {c.feature_path}")
-        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
-    else:
-        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
-
-    # setup audio processor
-    ap = AudioProcessor(**c.audio.to_dict())
-
-    # DISTRIBUTED
-    if num_gpus > 1:
-        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
-
-    # setup models
-    model_gen = setup_generator(c)
-    model_disc = setup_discriminator(c)
-
-    # setup criterion
-    criterion_gen = GeneratorLoss(c)
-    criterion_disc = DiscriminatorLoss(c)
-
-    if use_cuda:
-        model_gen.cuda()
-        criterion_gen.cuda()
-        model_disc.cuda()
-        criterion_disc.cuda()
-
-    # setup optimizers
-    # TODO: allow loading custom optimizers
-    optimizer_gen = None
-    optimizer_disc = None
-    optimizer_gen = getattr(torch.optim, c.optimizer)
-    optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
-    optimizer_disc = getattr(torch.optim, c.optimizer)
-
-    if c.discriminator_model == "hifigan_discriminator":
-        optimizer_disc = optimizer_disc(
-            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
-            lr=c.lr_disc,
-            **c.optimizer_params,
-        )
-    else:
-        optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
-
-    # schedulers
-    scheduler_gen = None
-    scheduler_disc = None
-    if "lr_scheduler_gen" in c:
-        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
-        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
-    if "lr_scheduler_disc" in c:
-        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
-        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
-
-    if args.restore_path:
-        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
-        checkpoint = torch.load(args.restore_path, map_location="cpu")
-        try:
-            print(" > Restoring Generator Model...")
-            model_gen.load_state_dict(checkpoint["model"])
-            print(" > Restoring Generator Optimizer...")
-            optimizer_gen.load_state_dict(checkpoint["optimizer"])
-            print(" > Restoring Discriminator Model...")
-            model_disc.load_state_dict(checkpoint["model_disc"])
-            print(" > Restoring Discriminator Optimizer...")
-            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
-            # restore schedulers if it is a continuing training.
-            if args.continue_path != "":
-                if "scheduler" in checkpoint and scheduler_gen is not None:
-                    print(" > Restoring Generator LR Scheduler...")
-                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
-                    # NOTE: Not sure if necessary
-                    scheduler_gen.optimizer = optimizer_gen
-                if "scheduler_disc" in checkpoint and scheduler_disc is not None:
-                    print(" > Restoring Discriminator LR Scheduler...")
-                    scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
-                    scheduler_disc.optimizer = optimizer_disc
-                    if c.lr_scheduler_disc == "ExponentialLR":
-                        scheduler_disc.last_epoch = checkpoint["epoch"]
-        except RuntimeError:
-            # restore only matching layers.
-            print(" > Partial model initialization...")
-            model_dict = model_gen.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
-            model_gen.load_state_dict(model_dict)
-
-            model_dict = model_disc.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
-            model_disc.load_state_dict(model_dict)
-            del model_dict
-
-        # reset lr if not continuing training.
if args.continue_path == "":
|
||||
for group in optimizer_gen.param_groups:
|
||||
group["lr"] = c.lr_gen
|
||||
|
||||
for group in optimizer_disc.param_groups:
|
||||
group["lr"] = c.lr_disc
|
||||
|
||||
print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
-    # DISTRIBUTED
-    if num_gpus > 1:
-        model_gen = DDP_th(model_gen, device_ids=[args.rank])
-        model_disc = DDP_th(model_disc, device_ids=[args.rank])
-
-    num_params = count_parameters(model_gen)
-    print(" > Generator has {} parameters".format(num_params), flush=True)
-    num_params = count_parameters(model_disc)
-    print(" > Discriminator has {} parameters".format(num_params), flush=True)
-
-    if args.restore_step == 0 or not args.best_path:
-        best_loss = float("inf")
-        print(" > Starting with inf best loss.")
-    else:
-        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
-        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
-        print(f" > Starting with best loss of {best_loss}.")
-    keep_all_best = c.get("keep_all_best", False)
-    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False
-
-    global_step = args.restore_step
-    for epoch in range(0, c.epochs):
-        c_logger.print_epoch_start(epoch, c.epochs)
-        _, global_step = train(
-            model_gen,
-            criterion_gen,
-            optimizer_gen,
-            model_disc,
-            criterion_disc,
-            optimizer_disc,
-            scheduler_gen,
-            scheduler_disc,
-            ap,
-            global_step,
-            epoch,
-        )
-        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch)
-        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
-        target_loss = eval_avg_loss_dict[c.target_loss]
-        best_loss = save_best_model(
-            target_loss,
-            best_loss,
-            model_gen,
-            optimizer_gen,
-            scheduler_gen,
-            model_disc,
-            optimizer_disc,
-            scheduler_disc,
-            global_step,
-            epoch,
-            OUT_PATH,
-            keep_all_best=keep_all_best,
-            keep_after=keep_after,
-            model_losses=eval_avg_loss_dict,
-        )
-
-
-if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
-    try:
-        main(args)
-    except KeyboardInterrupt:
-        remove_experiment_folder(OUT_PATH)
-        try:
-            sys.exit(0)
-        except SystemExit:
-            os._exit(0)  # pylint: disable=protected-access
-    except Exception:  # pylint: disable=broad-except
-        remove_experiment_folder(OUT_PATH)
-        traceback.print_exc()
-        sys.exit(1)
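The 600-plus deleted lines above boil down to one alternating optimization: update the generator against the discriminator's scores (plus optional feature-matching and PQMF multi-band terms), then update the discriminator on real audio and detached fakes. A minimal, self-contained sketch of that core loop with stand-in linear models and a hinge loss; none of the TTS classes are used and the dimensions are arbitrary:

import torch
from torch import nn

G = nn.Linear(80, 256)  # stand-in generator: mel frame -> waveform chunk
D = nn.Linear(256, 1)   # stand-in discriminator: waveform chunk -> score
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4)

for _ in range(3):
    mel = torch.randn(8, 80)
    real = torch.randn(8, 256)

    # generator step: push D's score on fakes up
    fake = G(mel)
    loss_G = -D(fake).mean()
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

    # discriminator step: hinge loss on real vs. detached fakes,
    # mirroring y_hat.detach() in the deleted script
    loss_D = torch.relu(1.0 - D(real)).mean() + torch.relu(1.0 + D(fake.detach())).mean()
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()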
@@ -1,431 +0,0 @@
-#!/usr/bin/env python3
-"""Trains WaveGrad vocoder models."""
-
-import os
-import sys
-import time
-import traceback
-
-import numpy as np
-import torch
-
-# DISTRIBUTED
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.optim import Adam
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from TTS.utils.arguments import init_training
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import init_distributed
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.training import setup_torch_training_env
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
-from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
-from TTS.vocoder.utils.io import save_best_model, save_checkpoint
-
-use_cuda, num_gpus = setup_torch_training_env(True, True)
-
-
-def setup_loader(ap, is_val=False, verbose=False):
-    if is_val and not c.run_eval:
-        loader = None
-    else:
-        dataset = WaveGradDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=c.seq_len,
-            hop_len=ap.hop_length,
-            pad_short=c.pad_short,
-            conv_pad=c.conv_pad,
-            is_training=not is_val,
-            return_segments=True,
-            use_noise_augment=False,
-            use_cache=c.use_cache,
-            verbose=verbose,
-        )
-        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            batch_size=c.batch_size,
-            shuffle=num_gpus <= 1,
-            drop_last=False,
-            sampler=sampler,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
-            pin_memory=False,
-        )
-
-    return loader
-
-
-def format_data(data):
-    # return a whole audio segment
-    m, x = data
-    x = x.unsqueeze(1)
-    if use_cuda:
-        m = m.cuda(non_blocking=True)
-        x = x.cuda(non_blocking=True)
-    return m, x
-
-
-def format_test_data(data):
-    # return a whole audio segment
-    m, x = data
-    m = m[None, ...]
-    x = x[None, None, ...]
-    if use_cuda:
-        m = m.cuda(non_blocking=True)
-        x = x.cuda(non_blocking=True)
-    return m, x
-
-
-def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch):
-    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
-    model.train()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
-    else:
-        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
-    end_time = time.time()
-    c_logger.print_train_start()
-    # setup noise schedule
-    noise_schedule = c["train_noise_schedule"]
-    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
-    if hasattr(model, "module"):
-        model.module.compute_noise_level(betas)
-    else:
-        model.compute_noise_level(betas)
-    for num_iter, data in enumerate(data_loader):
-        start_time = time.time()
-
-        # format data
-        m, x = format_data(data)
-        loader_time = time.time() - end_time
-
-        global_step += 1
-
-        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
-            # compute noisy input
-            if hasattr(model, "module"):
-                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-            else:
-                noise, x_noisy, noise_scale = model.compute_y_n(x)
-
-            # forward pass
-            noise_hat = model(x_noisy, m, noise_scale)
-
-            # compute losses
-            loss = criterion(noise, noise_hat)
-            loss_wavegrad_dict = {"wavegrad_loss": loss}
-
-        # check nan loss
-        if torch.isnan(loss).any():
-            raise RuntimeError(f"Detected NaN loss at step {global_step}.")
-
-        optimizer.zero_grad()
-
-        # backward pass with loss scaling
-        if c.mixed_precision:
-            scaler.scale(loss).backward()
-            scaler.unscale_(optimizer)
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            loss.backward()
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
-            optimizer.step()
-
-        # schedule update
-        if scheduler is not None:
-            scheduler.step()
-
-        # disconnect loss values
-        loss_dict = dict()
-        for key, value in loss_wavegrad_dict.items():
-            if isinstance(value, int):
-                loss_dict[key] = value
-            else:
-                loss_dict[key] = value.item()
-
-        # epoch/step timing
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # get current learning rates
-        current_lr = list(optimizer.param_groups)[0]["lr"]
-
-        # update avg stats
-        update_train_values = dict()
-        for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_train_values)
-
-        # print training stats
-        if global_step % c.print_step == 0:
-            log_dict = {
-                "step_time": [step_time, 2],
-                "loader_time": [loader_time, 4],
-                "current_lr": current_lr,
-                "grad_norm": grad_norm.item(),
-            }
-            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
-
-        if args.rank == 0:
-            # plot step stats
-            if global_step % 10 == 0:
-                iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
-                iter_stats.update(loss_dict)
-                tb_logger.tb_train_step_stats(global_step, iter_stats)
-
-            # save checkpoint
-            if global_step % c.save_step == 0:
-                if c.checkpoint:
-                    # save model
-                    save_checkpoint(
-                        model,
-                        optimizer,
-                        scheduler,
-                        None,
-                        None,
-                        None,
-                        global_step,
-                        epoch,
-                        OUT_PATH,
-                        model_losses=loss_dict,
-                        scaler=scaler.state_dict() if c.mixed_precision else None,
-                    )
-
-        end_time = time.time()
-
-    # print epoch stats
-    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
-
-    # Plot Training Epoch Stats
-    epoch_stats = {"epoch_time": epoch_time}
-    epoch_stats.update(keep_avg.avg_values)
-    if args.rank == 0:
-        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
-    # TODO: plot model stats
-    if c.tb_model_param_stats and args.rank == 0:
-        tb_logger.tb_model_weights(model, global_step)
-    return keep_avg.avg_values, global_step
-
-
-@torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
-    model.eval()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    end_time = time.time()
-    c_logger.print_eval_start()
-    for num_iter, data in enumerate(data_loader):
-        start_time = time.time()
-
-        # format data
-        m, x = format_data(data)
-        loader_time = time.time() - end_time
-
-        global_step += 1
-
-        # compute noisy input
-        if hasattr(model, "module"):
-            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-        else:
-            noise, x_noisy, noise_scale = model.compute_y_n(x)
-
-        # forward pass
-        noise_hat = model(x_noisy, m, noise_scale)
-
-        # compute losses
-        loss = criterion(noise, noise_hat)
-        loss_wavegrad_dict = {"wavegrad_loss": loss}
-
-        loss_dict = dict()
-        for key, value in loss_wavegrad_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict[key] = value
-            else:
-                loss_dict[key] = value.item()
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # update avg stats
-        update_eval_values = dict()
-        for key, value in loss_dict.items():
-            update_eval_values["avg_" + key] = value
-        update_eval_values["avg_loader_time"] = loader_time
-        update_eval_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_eval_values)
-
-        # print eval stats
-        if c.print_eval:
-            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
-
-    if args.rank == 0:
-        data_loader.dataset.return_segments = False
-        samples = data_loader.dataset.load_test_samples(1)
-        m, x = format_test_data(samples[0])
-
-        # setup noise schedule and inference
-        noise_schedule = c["test_noise_schedule"]
-        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
-        if hasattr(model, "module"):
-            model.module.compute_noise_level(betas)
-            # compute voice
-            x_pred = model.module.inference(m)
-        else:
-            model.compute_noise_level(betas)
-            # compute voice
-            x_pred = model.inference(m)
-
-        # compute spectrograms
-        figures = plot_results(x_pred, x, ap, global_step, "eval")
-        tb_logger.tb_eval_figures(global_step, figures)
-
-        # Sample audio
-        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
-        tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])
-
-        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
-        data_loader.dataset.return_segments = True
-
-    return keep_avg.avg_values
-
-
-def main(args):  # pylint: disable=redefined-outer-name
-    # pylint: disable=global-variable-undefined
-    global train_data, eval_data
-    print(f" > Loading wavs from: {c.data_path}")
-    if c.feature_path is not None:
-        print(f" > Loading features from: {c.feature_path}")
-        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
-    else:
-        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
-
-    # setup audio processor
-    ap = AudioProcessor(**c.audio.to_dict())
-
-    # DISTRIBUTED
-    if num_gpus > 1:
-        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
-
-    # setup models
-    model = setup_generator(c)
-
-    # scaler for mixed_precision
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
-
-    # setup optimizers
-    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
-
-    # schedulers
-    scheduler = None
-    if "lr_scheduler" in c:
-        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
-        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
-
-    # setup criterion
-    criterion = torch.nn.L1Loss().cuda()
-
-    if use_cuda:
-        model.cuda()
-        criterion.cuda()
-
-    if args.restore_path:
-        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
-        checkpoint = torch.load(args.restore_path, map_location="cpu")
-        try:
-            print(" > Restoring Model...")
-            model.load_state_dict(checkpoint["model"])
-            print(" > Restoring Optimizer...")
-            optimizer.load_state_dict(checkpoint["optimizer"])
-            if "scheduler" in checkpoint:
-                print(" > Restoring LR Scheduler...")
-                scheduler.load_state_dict(checkpoint["scheduler"])
-                # NOTE: Not sure if necessary
-                scheduler.optimizer = optimizer
-            if "scaler" in checkpoint and c.mixed_precision:
-                print(" > Restoring AMP Scaler...")
-                scaler.load_state_dict(checkpoint["scaler"])
-        except RuntimeError:
-            # restore only matching layers.
print(" > Partial model initialization...")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
-        # reset lr if not continuing training.
-        for group in optimizer.param_groups:
-            group["lr"] = c.lr
-
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
-    else:
-        args.restore_step = 0
-
-    # DISTRIBUTED
-    if num_gpus > 1:
-        model = DDP_th(model, device_ids=[args.rank])
-
-    num_params = count_parameters(model)
-    print(" > WaveGrad has {} parameters".format(num_params), flush=True)
-
-    if args.restore_step == 0 or not args.best_path:
-        best_loss = float("inf")
-        print(" > Starting with inf best loss.")
-    else:
-        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
-        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
-        print(f" > Starting with loaded last best loss {best_loss}.")
-    keep_all_best = c.get("keep_all_best", False)
-    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False
-
-    global_step = args.restore_step
-    for epoch in range(0, c.epochs):
-        c_logger.print_epoch_start(epoch, c.epochs)
-        _, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch)
-        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
-        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
-        target_loss = eval_avg_loss_dict[c.target_loss]
-        best_loss = save_best_model(
-            target_loss,
-            best_loss,
-            model,
-            optimizer,
-            scheduler,
-            None,
-            None,
-            None,
-            global_step,
-            epoch,
-            OUT_PATH,
-            keep_all_best=keep_all_best,
-            keep_after=keep_after,
-            model_losses=eval_avg_loss_dict,
-            scaler=scaler.state_dict() if c.mixed_precision else None,
-        )
-
-
-if __name__ == "__main__":
-    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
-    try:
-        main(args)
-    except KeyboardInterrupt:
-        remove_experiment_folder(OUT_PATH)
-        try:
-            sys.exit(0)
-        except SystemExit:
-            os._exit(0)  # pylint: disable=protected-access
-    except Exception:  # pylint: disable=broad-except
-        remove_experiment_folder(OUT_PATH)
-        traceback.print_exc()
-        sys.exit(1)
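The mixed-precision branch of the deleted WaveGrad step is the standard torch.cuda.amp recipe, worth seeing in isolation because the order of operations matters: backward on the scaled loss, unscale before clipping, then a scaler-mediated step. A minimal sketch with a stand-in model (requires a CUDA device; the clip value is arbitrary):

import torch

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 10, device="cuda")
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()  # stand-in for criterion(noise, noise_hat)

optimizer.zero_grad()
scaler.scale(loss).backward()   # gradients are computed on the scaled loss
scaler.unscale_(optimizer)      # unscale so clipping sees true gradient norms
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)          # skips the update if inf/NaN gradients were found
scaler.update()                 # adjusts the loss scale for the next iteration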
@@ -1,431 +0,0 @@
-#!/usr/bin/env python3
-"""Train WaveRNN vocoder model."""
-
-import os
-import random
-import sys
-import time
-import traceback
-
-import torch
-from torch.utils.data import DataLoader
-
-from TTS.tts.utils.visual import plot_spectrogram
-from TTS.utils.arguments import init_training
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.radam import RAdam
-from TTS.utils.training import setup_torch_training_env
-from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
-from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
-from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
-from TTS.vocoder.utils.generic_utils import setup_generator
-from TTS.vocoder.utils.io import save_best_model, save_checkpoint
-
-# from torch.utils.data.distributed import DistributedSampler
-
-
-use_cuda, num_gpus = setup_torch_training_env(True, True)
-
-
-def setup_loader(ap, is_val=False, verbose=False):
-    if is_val and not c.run_eval:
-        loader = None
-    else:
-        dataset = WaveRNNDataset(
-            ap=ap,
-            items=eval_data if is_val else train_data,
-            seq_len=c.seq_len,
-            hop_len=ap.hop_length,
-            pad=c.padding,
-            mode=c.mode,
-            mulaw=c.mulaw,
-            is_training=not is_val,
-            verbose=verbose,
-        )
-        # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
-        loader = DataLoader(
-            dataset,
-            shuffle=True,
-            collate_fn=dataset.collate,
-            batch_size=c.batch_size,
-            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
-            pin_memory=True,
-        )
-    return loader
-
-
-def format_data(data):
-    # setup input data
-    x_input = data[0]
-    mels = data[1]
-    y_coarse = data[2]
-
-    # dispatch data to GPU
-    if use_cuda:
-        x_input = x_input.cuda(non_blocking=True)
-        mels = mels.cuda(non_blocking=True)
-        y_coarse = y_coarse.cuda(non_blocking=True)
-
-    return x_input, mels, y_coarse
-
-
-def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
-    # create train loader
-    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
-    model.train()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    if use_cuda:
-        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
-    else:
-        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
-    end_time = time.time()
-    c_logger.print_train_start()
-    # train loop
-    for num_iter, data in enumerate(data_loader):
-        start_time = time.time()
-        x_input, mels, y_coarse = format_data(data)
-        loader_time = time.time() - end_time
-        global_step += 1
-
-        optimizer.zero_grad()
-
-        if c.mixed_precision:
-            # mixed precision training
-            with torch.cuda.amp.autocast():
-                y_hat = model(x_input, mels)
-                if isinstance(model.mode, int):
-                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-                else:
-                    y_coarse = y_coarse.float()
-                y_coarse = y_coarse.unsqueeze(-1)
-                # compute losses
-                loss = criterion(y_hat, y_coarse)
-            scaler.scale(loss).backward()
-            scaler.unscale_(optimizer)
-            if c.grad_clip > 0:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            # full precision training
-            y_hat = model(x_input, mels)
-            if isinstance(model.mode, int):
-                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-            else:
-                y_coarse = y_coarse.float()
-            y_coarse = y_coarse.unsqueeze(-1)
-            # compute losses
-            loss = criterion(y_hat, y_coarse)
-            if torch.isnan(loss).any():
-                raise RuntimeError(" [!] NaN loss. Exiting ...")
-            loss.backward()
-            if c.grad_clip > 0:
-                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
-            optimizer.step()
-
-        if scheduler is not None:
-            scheduler.step()
-
-        # get the current learning rate
-        cur_lr = list(optimizer.param_groups)[0]["lr"]
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        update_train_values = dict()
-        loss_dict = dict()
-        loss_dict["model_loss"] = loss.item()
-        for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_train_values)
-
-        # print training stats
-        if global_step % c.print_step == 0:
-            log_dict = {
-                "step_time": [step_time, 2],
-                "loader_time": [loader_time, 4],
-                "current_lr": cur_lr,
-            }
-            c_logger.print_train_step(
-                batch_n_iter,
-                num_iter,
-                global_step,
-                log_dict,
-                loss_dict,
-                keep_avg.avg_values,
-            )
-
-        # plot step stats
-        if global_step % 10 == 0:
-            iter_stats = {"lr": cur_lr, "step_time": step_time}
-            iter_stats.update(loss_dict)
-            tb_logger.tb_train_step_stats(global_step, iter_stats)
-
-        # save checkpoint
-        if global_step % c.save_step == 0:
-            if c.checkpoint:
-                # save model
-                save_checkpoint(
-                    model,
-                    optimizer,
-                    scheduler,
-                    None,
-                    None,
-                    None,
-                    global_step,
-                    epoch,
-                    OUT_PATH,
-                    model_losses=loss_dict,
-                    scaler=scaler.state_dict() if c.mixed_precision else None,
-                )
-
-            # synthesize a full voice
-            rand_idx = random.randrange(0, len(train_data))
-            wav_path = (
-                train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
-            )
-            wav = ap.load_wav(wav_path)
-            ground_mel = ap.melspectrogram(wav)
-            ground_mel = torch.FloatTensor(ground_mel)
-            if use_cuda:
-                ground_mel = ground_mel.cuda(non_blocking=True)
-            sample_wav = model.inference(
-                ground_mel,
-                c.batched,
-                c.target_samples,
-                c.overlap_samples,
-            )
-            predict_mel = ap.melspectrogram(sample_wav)
-
-            # compute spectrograms
-            figures = {
-                "train/ground_truth": plot_spectrogram(ground_mel.T),
-                "train/prediction": plot_spectrogram(predict_mel.T),
-            }
-            tb_logger.tb_train_figures(global_step, figures)
-
-            # Sample audio
-            tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"])
-        end_time = time.time()
-
-    # print epoch stats
-    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
-
-    # Plot Training Epoch Stats
-    epoch_stats = {"epoch_time": epoch_time}
-    epoch_stats.update(keep_avg.avg_values)
-    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
-    # TODO: plot model stats
-    # if c.tb_model_param_stats:
-    #     tb_logger.tb_model_weights(model, global_step)
-    return keep_avg.avg_values, global_step
-
-
-@torch.no_grad()
-def evaluate(model, criterion, ap, global_step, epoch):
-    # create train loader
-    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
-    model.eval()
-    epoch_time = 0
-    keep_avg = KeepAverage()
-    end_time = time.time()
-    c_logger.print_eval_start()
-    with torch.no_grad():
-        for num_iter, data in enumerate(data_loader):
-            start_time = time.time()
-            # format data
-            x_input, mels, y_coarse = format_data(data)
-            loader_time = time.time() - end_time
-            global_step += 1
-
-            y_hat = model(x_input, mels)
-            if isinstance(model.mode, int):
-                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
-            else:
-                y_coarse = y_coarse.float()
-            y_coarse = y_coarse.unsqueeze(-1)
-            loss = criterion(y_hat, y_coarse)
-            # Compute avg loss
-            # if num_gpus > 1:
-            #     loss = reduce_tensor(loss.data, num_gpus)
-            loss_dict = dict()
-            loss_dict["model_loss"] = loss.item()
-
-            step_time = time.time() - start_time
-            epoch_time += step_time
-
-            # update avg stats
-            update_eval_values = dict()
-            for key, value in loss_dict.items():
-                update_eval_values["avg_" + key] = value
-            update_eval_values["avg_loader_time"] = loader_time
-            update_eval_values["avg_step_time"] = step_time
-            keep_avg.update_values(update_eval_values)
-
-            # print eval stats
-            if c.print_eval:
-                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
-
-    if epoch % c.test_every_epochs == 0 and epoch != 0:
-        # synthesize a full voice
-        rand_idx = random.randrange(0, len(eval_data))
-        wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
-        wav = ap.load_wav(wav_path)
-        ground_mel = ap.melspectrogram(wav)
-        ground_mel = torch.FloatTensor(ground_mel)
-        if use_cuda:
-            ground_mel = ground_mel.cuda(non_blocking=True)
-        sample_wav = model.inference(
-            ground_mel,
-            c.batched,
-            c.target_samples,
-            c.overlap_samples,
-        )
-        predict_mel = ap.melspectrogram(sample_wav)
-
-        # Sample audio
-        tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"])
-
-        # compute spectrograms
-        figures = {
-            "eval/ground_truth": plot_spectrogram(ground_mel.T),
-            "eval/prediction": plot_spectrogram(predict_mel.T),
-        }
-        tb_logger.tb_eval_figures(global_step, figures)
-
-    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
-    return keep_avg.avg_values
-
-
-# FIXME: move args definition/parsing inside of main?
-def main(args):  # pylint: disable=redefined-outer-name
-    # pylint: disable=global-variable-undefined
-    global train_data, eval_data
-
-    # setup audio processor
-    ap = AudioProcessor(**c.audio.to_dict())
-
-    print(f" > Loading wavs from: {c.data_path}")
-    if c.feature_path is not None:
-        print(f" > Loading features from: {c.feature_path}")
-        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
-    else:
-        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
-    # setup model
-    model_wavernn = setup_generator(c)
-
-    # setup amp scaler
-    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
-
-    # define train functions
-    if c.mode == "mold":
-        criterion = discretized_mix_logistic_loss
-    elif c.mode == "gauss":
-        criterion = gaussian_loss
-    elif isinstance(c.mode, int):
-        criterion = torch.nn.CrossEntropyLoss()
-
-    if use_cuda:
-        model_wavernn.cuda()
-        if isinstance(c.mode, int):
-            criterion.cuda()
-
-    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)
-
-    scheduler = None
-    if "lr_scheduler" in c:
-        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
-        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
-    # slow start for the first 5 epochs
-    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
-    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
-
-    # restore any checkpoint
-    if args.restore_path:
-        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
-        checkpoint = torch.load(args.restore_path, map_location="cpu")
-        try:
-            print(" > Restoring Model...")
-            model_wavernn.load_state_dict(checkpoint["model"])
-            print(" > Restoring Optimizer...")
-            optimizer.load_state_dict(checkpoint["optimizer"])
-            if "scheduler" in checkpoint:
-                print(" > Restoring Generator LR Scheduler...")
-                scheduler.load_state_dict(checkpoint["scheduler"])
-                scheduler.optimizer = optimizer
-            if "scaler" in checkpoint and c.mixed_precision:
-                print(" > Restoring AMP Scaler...")
-                scaler.load_state_dict(checkpoint["scaler"])
-        except RuntimeError:
-            # restore only matching layers.
print(" > Partial model initialization...")
|
||||
model_dict = model_wavernn.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model_wavernn.load_state_dict(model_dict)
|
||||
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
# DISTRIBUTED
|
||||
# if num_gpus > 1:
|
||||
# model = apply_gradient_allreduce(model)
|
||||
|
||||
num_parameters = count_parameters(model_wavernn)
|
||||
print(" > Model has {} parameters".format(num_parameters), flush=True)
|
||||
|
||||
if args.restore_step == 0 or not args.best_path:
|
||||
best_loss = float("inf")
|
||||
print(" > Starting with inf best loss.")
|
||||
else:
|
||||
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with loaded last best loss {best_loss}.")
|
||||
keep_all_best = c.get("keep_all_best", False)
|
||||
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
|
||||
|
||||
global_step = args.restore_step
|
||||
for epoch in range(0, c.epochs):
|
||||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
_, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch)
|
||||
eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict["avg_model_loss"]
|
||||
best_loss = save_best_model(
|
||||
target_loss,
|
||||
best_loss,
|
||||
model_wavernn,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
keep_all_best=keep_all_best,
|
||||
keep_after=keep_after,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
|
||||
try:
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
except Exception: # pylint: disable=broad-except
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
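The removed WaveRNN script switches its whole objective on c.mode: "mold" and "gauss" regress distribution parameters against float targets, while an integer mode turns WaveRNN into a classifier over quantized amplitude levels trained with CrossEntropyLoss. The shape handling around the criterion call is easiest to see in isolation; a sketch with stand-in tensors (B batch, T timesteps, C amplitude classes; the L1 branch merely stands in for the mixture losses, which live in TTS.vocoder.utils.distribution):

import torch
import torch.nn.functional as F

B, T, C = 4, 100, 512
mode = C  # an integer mode means "classify each sample into C levels"

y_hat = torch.randn(B, T, C)            # network output
y_coarse = torch.randint(0, C, (B, T))  # quantized amplitude targets

if isinstance(mode, int):
    # classification: cross_entropy wants (N, C, T) logits and long targets
    loss = F.cross_entropy(y_hat.transpose(1, 2), y_coarse)
else:
    # regression modes compare float targets against predicted parameters
    loss = F.l1_loss(y_hat.mean(-1), y_coarse.float())
print(loss.item())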
TTS/trainer.py: 999 changed lines
File diff suppressed because it is too large
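TTS/trainer.py is the heart of this commit, but its diff is suppressed above, so only its call surface is visible. Everything the deleted scripts did by hand (epoch loops, checkpointing, Tensorboard plots, distributed setup) now presumably sits behind Trainer.fit(); the surface actually exercised by the two entry points in this diff is just the following, and anything beyond these calls is an assumption rather than the suppressed implementation:

import sys

from TTS.trainer import Trainer, init_training

# parse argv, load the config, create the output folder and loggers
args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)

# cudnn_benchmark is the only extra knob the TTS entry point passes
trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
trainer.fit()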
@@ -1,245 +0,0 @@
-import copy
-from abc import ABC, abstractmethod
-from typing import Dict
-
-import torch
-from torch import nn
-
-from TTS.tts.utils.data import sequence_mask
-from TTS.utils.generic_utils import format_aux_input
-from TTS.utils.training import gradual_training_scheduler
-
-
-class TacotronAbstract(ABC, nn.Module):
-    def __init__(
-        self,
-        num_chars,
-        num_speakers,
-        r,
-        postnet_output_dim=80,
-        decoder_output_dim=80,
-        attn_type="original",
-        attn_win=False,
-        attn_norm="softmax",
-        prenet_type="original",
-        prenet_dropout=True,
-        prenet_dropout_at_inference=False,
-        forward_attn=False,
-        trans_agent=False,
-        forward_attn_mask=False,
-        location_attn=True,
-        attn_K=5,
-        separate_stopnet=True,
-        bidirectional_decoder=False,
-        double_decoder_consistency=False,
-        ddc_r=None,
-        encoder_in_features=512,
-        decoder_in_features=512,
-        d_vector_dim=None,
-        use_gst=False,
-        gst=None,
-        gradual_training=None,
-    ):
-        """Abstract Tacotron class"""
-        super().__init__()
-        self.num_chars = num_chars
-        self.r = r
-        self.decoder_output_dim = decoder_output_dim
-        self.postnet_output_dim = postnet_output_dim
-        self.use_gst = use_gst
-        self.gst = gst
-        self.num_speakers = num_speakers
-        self.bidirectional_decoder = bidirectional_decoder
-        self.double_decoder_consistency = double_decoder_consistency
-        self.ddc_r = ddc_r
-        self.attn_type = attn_type
-        self.attn_win = attn_win
-        self.attn_norm = attn_norm
-        self.prenet_type = prenet_type
-        self.prenet_dropout = prenet_dropout
-        self.prenet_dropout_at_inference = prenet_dropout_at_inference
-        self.forward_attn = forward_attn
-        self.trans_agent = trans_agent
-        self.forward_attn_mask = forward_attn_mask
-        self.location_attn = location_attn
-        self.attn_K = attn_K
-        self.separate_stopnet = separate_stopnet
-        self.encoder_in_features = encoder_in_features
-        self.decoder_in_features = decoder_in_features
-        self.d_vector_dim = d_vector_dim
-        self.gradual_training = gradual_training
-
-        # layers
-        self.embedding = None
-        self.encoder = None
-        self.decoder = None
-        self.postnet = None
-
-        # multispeaker
if self.d_vector_dim is None:
|
||||
# if d_vector_dim is None we need use the nn.Embedding, with default d_vector_dim
|
||||
self.use_d_vectors = False
|
||||
else:
|
||||
# if d_vector_dim is not None we need use speaker embedding per sample
|
||||
self.use_d_vectors = True
|
||||
|
||||
# global style token
|
||||
if self.gst and use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# model states
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
self.decoder.set_r(state["r"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
input_mask = sequence_mask(text_lengths)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len)
|
||||
return input_mask, output_mask
|
||||
|
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
"""Run backwards decoder"""
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
|
||||
)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
|
||||
"""Double Decoder Consistency"""
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask
|
||||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
"""Compute speaker embedding vectors"""
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
for k_token, v_amplifier in style_input.items():
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + embedded_speakers_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
|
||||
return outputs
|
||||
|
||||
#############################
|
||||
# CALLBACKS
|
||||
#############################
|
||||
|
||||
def on_epoch_start(self, trainer):
|
||||
"""Callback for setting values wrt gradual training schedule.
|
||||
|
||||
Args:
|
||||
trainer (TrainerTTS): TTS trainer object that is used to train this model.
|
||||
"""
|
||||
if self.gradual_training:
|
||||
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
|
||||
trainer.config.r = r
|
||||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
|
||||
trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
|
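To make `compute_masks` concrete, here is a hedged illustration of the padding-to-a-multiple-of-r logic above. The `sequence_mask` stand-in below is a minimal sketch assuming the usual semantics of `TTS.tts.utils.data.sequence_mask` (a boolean [B, T] tensor that is True inside each sequence); verify against your version.

import torch

def sequence_mask(lengths, max_len=None):
    # minimal stand-in for TTS.tts.utils.data.sequence_mask (assumed semantics)
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids < lengths.unsqueeze(1)

mel_lengths = torch.tensor([5, 7])
r = 3                                  # decoder reduction factor
max_len = int(mel_lengths.max())       # 7
max_len += (r - max_len % r) % r       # padded up to 9, the next multiple of r
print(sequence_mask(mel_lengths, max_len).int())
# tensor([[1, 1, 1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 0, 0]])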
@ -1,709 +0,0 @@
# -*- coding: utf-8 -*-

import importlib
import logging
import os
import time
from argparse import Namespace
from typing import Dict, List, Tuple, Union

import torch
from coqpit import Coqpit

# DISTRIBUTED
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.trainer import TrainerAbstract
from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env


# pylint: disable=import-outside-toplevel, too-many-public-methods

class TrainerTTS(TrainerAbstract):
    use_cuda, num_gpus = setup_torch_training_env(True, False)

    def __init__(
        self,
        args: Union[Coqpit, Namespace],
        config: Coqpit,
        c_logger: ConsoleLogger = None,
        tb_logger: TensorboardLogger = None,
        model: nn.Module = None,
        output_path: str = None,
    ) -> None:
        self.args = args
        self.config = config
        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
        if tb_logger is None:
            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
        else:
            self.tb_logger = tb_logger
        self.output_path = output_path

        self.total_steps_done = 0
        self.epochs_done = 0
        self.restore_step = 0
        self.best_loss = float("inf")
        self.train_loader = None
        self.eval_loader = None
        self.output_audio_path = os.path.join(output_path, "test_audios")

        self.keep_avg_train = None
        self.keep_avg_eval = None

        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
        self._setup_logger_config(log_file)

        # model, audio processor, datasets, loss
        # init audio processor
        self.ap = AudioProcessor(**self.config.audio.to_dict())

        # init character processor
        self.model_characters = self.get_character_processor(self.config)

        # load dataset samples
        self.data_train, self.data_eval = load_meta_data(self.config.datasets)

        # default speaker manager
        self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train)

        # init TTS model
        if model is not None:
            self.model = model
        else:
            self.model = self.get_model(
                len(self.model_characters),
                self.speaker_manager.num_speakers,
                self.config,
                self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None,
            )

        # setup criterion
        self.criterion = self.get_criterion(self.config)

        # DISTRIBUTED
        if self.num_gpus > 1:
            init_distributed(
                args.rank,
                self.num_gpus,
                args.group_id,
                self.config.distributed_backend,
                self.config.distributed_url,
            )

        if self.use_cuda:
            self.model.cuda()
            self.criterion.cuda()

        # scalers for mixed precision training
        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None

        # setup optimizer
        self.optimizer = self.get_optimizer(self.model, self.config)

        if self.args.restore_path:
            self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
                self.config, args.restore_path, self.model, self.optimizer, self.scaler
            )

        # setup scheduler
        self.scheduler = self.get_scheduler(self.config, self.optimizer)

        # DISTRIBUTED
        if self.num_gpus > 1:
            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)

        # count model size
        num_params = count_parameters(self.model)
        print("\n > Model has {} parameters".format(num_params))

    @staticmethod
    def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module:
        model = setup_model(num_chars, num_speakers, config, d_vector_dim)
        return model

    @staticmethod
    def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
        optimizer_name = config.optimizer
        optimizer_params = config.optimizer_params
        if optimizer_name.lower() == "radam":
            module = importlib.import_module("TTS.utils.radam")
            optimizer = getattr(module, "RAdam")
        else:
            optimizer = getattr(torch.optim, optimizer_name)
        return optimizer(model.parameters(), lr=config.lr, **optimizer_params)

    @staticmethod
    def get_character_processor(config: Coqpit) -> str:
        # setup custom characters if set in the config file.
        # TODO: implement CharacterProcessor
        if config.characters is not None:
            symbols, phonemes = make_symbols(**config.characters.to_dict())
        else:
            from TTS.tts.utils.text.symbols import phonemes, symbols
        model_characters = phonemes if config.use_phonemes else symbols
        return model_characters

    @staticmethod
    def get_speaker_manager(
        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None
    ) -> SpeakerManager:
        speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path)
        return speaker_manager

    @staticmethod
    def get_scheduler(
        config: Coqpit, optimizer: torch.optim.Optimizer
    ) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
        lr_scheduler = config.lr_scheduler
        lr_scheduler_params = config.lr_scheduler_params
        if lr_scheduler is None:
            return None
        if lr_scheduler.lower() == "noamlr":
            from TTS.utils.training import NoamLR

            scheduler = NoamLR
        else:
            scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler)
        return scheduler(optimizer, **lr_scheduler_params)

    @staticmethod
    def get_criterion(config: Coqpit) -> nn.Module:
        return setup_loss(config)

    def restore_model(
        self,
        config: Coqpit,
        restore_path: str,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scaler: torch.cuda.amp.GradScaler = None,
    ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
        print(" > Restoring from %s ..." % os.path.basename(restore_path))
        checkpoint = torch.load(restore_path)
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scaler" in checkpoint and config.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group["lr"] = self.config.lr
        print(
            " > Model restored from step %d" % checkpoint["step"],
        )
        restore_step = checkpoint["step"]
        return model, optimizer, scaler, restore_step

    def _get_loader(
        self,
        r: int,
        ap: AudioProcessor,
        is_eval: bool,
        data_items: List,
        verbose: bool,
        speaker_ids: Union[Dict, List],
        d_vectors: Union[Dict, List],
    ) -> DataLoader:
        if is_eval and not self.config.run_eval:
            loader = None
        else:
            dataset = TTSDataset(
                outputs_per_step=r,
                text_cleaner=self.config.text_cleaner,
                compute_linear_spec=self.config.model.lower() == "tacotron",
                meta_data=data_items,
                ap=ap,
                tp=self.config.characters,
                add_blank=self.config["add_blank"],
                batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size,
                min_seq_len=self.config.min_seq_len,
                max_seq_len=self.config.max_seq_len,
                phoneme_cache_path=self.config.phoneme_cache_path,
                use_phonemes=self.config.use_phonemes,
                phoneme_language=self.config.phoneme_language,
                enable_eos_bos=self.config.enable_eos_bos_chars,
                use_noise_augment=not is_eval,
                verbose=verbose,
                speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None,
                d_vector_mapping=d_vectors
                if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file
                else None,
            )

            if self.config.use_phonemes and self.config.compute_input_seq_cache:
                # precompute phonemes to have a better estimate of sequence lengths.
                dataset.compute_input_seq(self.config.num_loader_workers)
            dataset.sort_items()

            sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None
            loader = DataLoader(
                dataset,
                batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size,
                shuffle=False,
                collate_fn=dataset.collate_fn,
                drop_last=False,
                sampler=sampler,
                num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def get_train_dataloader(
        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
    ) -> DataLoader:
        return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors)

    def get_eval_dataloder(
        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict
    ) -> DataLoader:
        return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors)

    def format_batch(self, batch: List) -> Dict:
        # setup input batch
        text_input = batch[0]
        text_lengths = batch[1]
        speaker_names = batch[2]
        linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
        mel_input = batch[4]
        mel_lengths = batch[5]
        stop_targets = batch[6]
        item_idx = batch[7]
        d_vectors = batch[8]
        speaker_ids = batch[9]
        attn_mask = batch[10]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indices.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets wrt. the reduction factor; we predict a single stop token per iteration.
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

        # dispatch batch to GPU
        if self.use_cuda:
            text_input = to_cuda(text_input)
            text_lengths = to_cuda(text_lengths)
            mel_input = to_cuda(mel_input)
            mel_lengths = to_cuda(mel_lengths)
            linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None
            stop_targets = to_cuda(stop_targets)
            attn_mask = to_cuda(attn_mask) if attn_mask is not None else None
            durations = to_cuda(durations) if attn_mask is not None else None
            if speaker_ids is not None:
                speaker_ids = to_cuda(speaker_ids)
            if d_vectors is not None:
                d_vectors = to_cuda(d_vectors)

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": max_text_length,
            "max_spec_length": max_spec_length,
            "item_idx": item_idx,
        }

    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        if hasattr(self.model, "module"):
            return self.model.module.train_step(batch, criterion)
        return self.model.train_step(batch, criterion)

    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
        self.on_train_step_start()
        step_start_time = time.time()

        # format data
        batch = self.format_batch(batch)
        loader_time = time.time() - loader_start_time

        # zero-out optimizer
        self.optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
            outputs, loss_dict = self._train_step(batch, self.criterion)

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")

        # optimizer step
        if self.config.mixed_precision:
            # model optimizer step in mixed precision mode
            self.scaler.scale(loss_dict["loss"]).backward()
            self.scaler.unscale_(self.optimizer)
            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True)
            self.optimizer.step()

        step_time = time.time() - step_start_time

        # setup lr
        if self.config.lr_scheduler:
            self.scheduler.step()

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        self.keep_avg_train.update_values(update_train_values)

        # print training progress
        current_lr = self.optimizer.param_groups[0]["lr"]
        if self.total_steps_done % self.config.print_step == 0:
            log_dict = {
                "max_spec_length": [batch["max_spec_length"], 1],  # value, precision
                "max_text_length": [batch["max_text_length"], 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            self.c_logger.print_train_step(
                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
            )

        if self.args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if self.total_steps_done % self.config.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)

            if self.total_steps_done % self.config.save_step == 0:
                if self.config.checkpoint:
                    # save model
                    save_checkpoint(
                        self.model,
                        self.optimizer,
                        self.total_steps_done,
                        self.epochs_done,
                        self.config.r,
                        self.output_path,
                        model_loss=loss_dict["loss"],
                        characters=self.model_characters,
                        scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
                    )
                # training visualizations
                if hasattr(self.model, "module"):
                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
                else:
                    figures, audios = self.model.train_log(self.ap, batch, outputs)
                self.tb_logger.tb_train_figures(self.total_steps_done, figures)
                self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
        self.total_steps_done += 1
        self.on_train_step_end()
        return outputs, loss_dict

    def train_epoch(self) -> None:
        self.model.train()
        epoch_start_time = time.time()
        if self.use_cuda:
            batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
        else:
            batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
        self.c_logger.print_train_start()
        for cur_step, batch in enumerate(self.train_loader):
            loader_start_time = time.time()
            _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
        epoch_time = time.time() - epoch_start_time
        # Plot epoch stats
        if self.args.rank == 0:
            epoch_stats = {"epoch_time": epoch_time}
            epoch_stats.update(self.keep_avg_train.avg_values)
            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
            if self.config.tb_model_param_stats:
                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)

    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
        if hasattr(self.model, "module"):
            return self.model.module.eval_step(batch, self.criterion)
        return self.model.eval_step(batch, self.criterion)

    def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
        with torch.no_grad():
            step_start_time = time.time()

            with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
                outputs, loss_dict = self._eval_step(batch)

            step_time = time.time() - step_start_time

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_eval_values = dict()
            for key, value in loss_dict.items():
                update_eval_values["avg_" + key] = value
            update_eval_values["avg_step_time"] = step_time
            self.keep_avg_eval.update_values(update_eval_values)

            if self.config.print_eval:
                self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values)
        return outputs, loss_dict

    def eval_epoch(self) -> None:
        self.model.eval()
        self.c_logger.print_eval_start()
        loader_start_time = time.time()
        batch = None
        for cur_step, batch in enumerate(self.eval_loader):
            # format data
            batch = self.format_batch(batch)
            loader_time = time.time() - loader_start_time
            self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
            outputs, _ = self.eval_step(batch, cur_step)
        # Plot epoch stats and samples from the last batch.
        if self.args.rank == 0:
            if hasattr(self.model, "module"):
                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
            else:
                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
            self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
            self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

    def test_run(self) -> None:
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_aux_inputs()
        for idx, sen in enumerate(test_sentences):
            wav, alignment, model_outputs, _ = synthesis(
                self.model,
                sen,
                self.config,
                self.use_cuda,
                self.ap,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                enable_eos_bos_chars=self.config.enable_eos_bos_chars,
                use_griffin_lim=True,
                do_trim_silence=False,
            ).values()

            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
            os.makedirs(file_path, exist_ok=True)
            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
            self.ap.save_wav(wav, file_path)
            test_audios["{}-audio".format(idx)] = wav
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
            test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)

        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
        self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

    def _get_aux_inputs(self) -> Dict:
        # setup speaker_id
        speaker_id = 0 if self.config.use_speaker_embedding else None
        # setup d_vector
        d_vector = (
            self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
            else None
        )
        # setup style_mel
        if self.config.has("gst_style_input"):
            style_wav = self.config.gst_style_input
        else:
            style_wav = None
        if style_wav is None and "use_gst" in self.config and self.config.use_gst:
            # initialize GST with a zero dict.
            style_wav = {}
            print("WARNING: You did not provide a GST style wav; using a zero mapping for all style tokens!")
            for i in range(self.config.gst["gst_num_style_tokens"]):
                style_wav[str(i)] = 0
        aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
        return aux_inputs

    def fit(self) -> None:
        if self.restore_step != 0 or self.args.best_path:
            print(f" > Restoring best loss from {os.path.basename(self.args.best_path)} ...")
            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
            print(f" > Starting with loaded last best loss {self.best_loss}.")

        # define data loaders
        self.train_loader = self.get_train_dataloader(
            self.config.r,
            self.ap,
            self.data_train,
            verbose=True,
            speaker_ids=self.speaker_manager.speaker_ids,
            d_vectors=self.speaker_manager.d_vectors,
        )
        self.eval_loader = (
            self.get_eval_dataloder(
                self.config.r,
                self.ap,
                self.data_eval,
                verbose=True,
                speaker_ids=self.speaker_manager.speaker_ids,
                d_vectors=self.speaker_manager.d_vectors,
            )
            if self.config.run_eval
            else None
        )

        self.total_steps_done = self.restore_step

        for epoch in range(0, self.config.epochs):
            self.on_epoch_start()
            self.keep_avg_train = KeepAverage()
            self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
            self.epochs_done = epoch
            self.c_logger.print_epoch_start(epoch, self.config.epochs)
            self.train_epoch()
            if self.config.run_eval:
                self.eval_epoch()
            if epoch >= self.config.test_delay_epochs and self.args.rank <= 0:
                self.test_run()
            self.c_logger.print_epoch_end(
                epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
            )
            self.save_best_model()
            self.on_epoch_end()

    def save_best_model(self) -> None:
        self.best_loss = save_best_model(
            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
            self.best_loss,
            self.model,
            self.optimizer,
            self.total_steps_done,
            self.epochs_done,
            self.config.r,
            self.output_path,
            self.model_characters,
            keep_all_best=self.config.keep_all_best,
            keep_after=self.config.keep_after,
            scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
        )

    @staticmethod
    def _setup_logger_config(log_file: str) -> None:
        logging.basicConfig(
            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
        )

    def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_epoch_start"):
            self.model.on_epoch_start(self)

        if hasattr(self.criterion, "on_epoch_start"):
            self.criterion.on_epoch_start(self)

        if hasattr(self.optimizer, "on_epoch_start"):
            self.optimizer.on_epoch_start(self)

    def on_epoch_end(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_epoch_end"):
            self.model.on_epoch_end(self)

        if hasattr(self.criterion, "on_epoch_end"):
            self.criterion.on_epoch_end(self)

        if hasattr(self.optimizer, "on_epoch_end"):
            self.optimizer.on_epoch_end(self)

    def on_train_step_start(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_train_step_start"):
            self.model.on_train_step_start(self)

        if hasattr(self.criterion, "on_train_step_start"):
            self.criterion.on_train_step_start(self)

        if hasattr(self.optimizer, "on_train_step_start"):
            self.optimizer.on_train_step_start(self)

    def on_train_step_end(self) -> None:  # pylint: disable=no-self-use
        if hasattr(self.model, "on_train_step_end"):
            self.model.on_train_step_end(self)

        if hasattr(self.criterion, "on_train_step_end"):
            self.criterion.on_train_step_end(self)

        if hasattr(self.optimizer, "on_train_step_end"):
            self.optimizer.on_train_step_end(self)
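The trainer above assumes models expose `train_step`/`eval_step` methods returning `(outputs, loss_dict)`, with backpropagation driven by `loss_dict["loss"]`. A minimal hedged sketch of that contract with a toy regression model (the ToyModel class is hypothetical, for illustration only):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def train_step(self, batch, criterion):
        # mirrors the interface driven by TrainerTTS._train_step above
        y_hat = self.layer(batch["x"])
        loss = criterion(y_hat, batch["y"])
        outputs = {"y_hat": y_hat}
        loss_dict = {"loss": loss}  # "loss" is the key the trainer backpropagates
        return outputs, loss_dict

    def eval_step(self, batch, criterion):
        with torch.no_grad():
            return self.train_step(batch, criterion)

model = ToyModel()
batch = {"x": torch.randn(4, 10), "y": torch.randn(4, 1)}
outputs, loss_dict = model.train_step(batch, nn.MSELoss())
loss_dict["loss"].backward()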
@ -1,182 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Argument parser for training scripts."""

import argparse
import glob
import os
import re

import torch

from TTS.config import load_config
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import copy_model_files
from TTS.utils.logging import ConsoleLogger, TensorboardLogger


def init_arguments(argv):
    """Build the command line argument parser for training scripts.

    Args:
        argv (list): This is a list of input arguments as given by sys.argv

    Returns:
        argparse.ArgumentParser: Parser with the training arguments registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--continue_path",
        type=str,
        help=(
            "Training output folder to continue training. Used to continue "
            "a training. If it is used, 'config_path' is ignored."
        ),
        default="",
        required="--config_path" not in argv,
    )
    parser.add_argument(
        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
    )
    parser.add_argument(
        "--best_path",
        type=str,
        help=(
            "Best model file to be used for extracting the best loss. "
            "If not specified, the latest best model in the continue path is used."
        ),
        default="",
    )
    parser.add_argument(
        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
    )
    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")

    return parser


def get_last_checkpoint(path):
    """Get the latest checkpoint and/or best model in path.

    It is based on globbing for `*.pth.tar` and the RegEx
    `(checkpoint|best_model)_([0-9]+)`.

    Args:
        path (str): Path to the directory whose model files are compared.

    Raises:
        ValueError: If no checkpoint or best_model files are found.

    Returns:
        Tuple[str, str]: Last checkpoint filename and last best model filename.
    """
    file_names = glob.glob(os.path.join(path, "*.pth.tar"))
    last_models = {}
    last_model_nums = {}
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # pass over all the checkpoint files and find
        # the one with the largest model number suffix.
        for file_name in file_names:
            match = re.search(f"{key}_([0-9]+)", file_name)
            if match is not None:
                model_num = int(match.groups()[0])
                if last_model_num is None or model_num > last_model_num:
                    last_model_num = model_num
                    last_model = file_name

        # if no checkpoint was found above, fall back to
        # the checkpoint with the latest modification date.
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = torch.load(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint, just the best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = None
    # finally, check if the last best model is more recent than the checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]


def process_args(args):
    """Process parsed command line arguments.

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments.

    Returns:
        config (Coqpit): Config parameters.
        experiment_path (str): Path to save models and logs.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.
        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
            the TensorBoard logging.
    """
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training run from its output folder
        experiment_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # setup output paths and read configs
    config = load_config(args.config_path)
    # override config values from command-line args
    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    if config.mixed_precision:
        print(" > Mixed precision mode is ON")
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
    audio_path = os.path.join(experiment_path, "test_audios")
    # setup rank 0 process in distributed training
    tb_logger = None
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file,
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters_config"):
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(experiment_path, 0o775)
        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
        # write model desc to tensorboard
        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, tb_logger


def init_training(argv):
    """Initialization of a training run."""
    parser = init_arguments(argv)
    args = parser.parse_known_args()
    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
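A small hedged demo of the checkpoint-selection regex used by get_last_checkpoint above; the filenames are made up for illustration:

import re

file_names = ["run/checkpoint_1000.pth.tar", "run/checkpoint_2500.pth.tar", "run/best_model_1800.pth.tar"]
for key in ["checkpoint", "best_model"]:
    # pick the file with the largest step suffix for each kind of model file
    latest = max(
        (f for f in file_names if re.search(f"{key}_([0-9]+)", f)),
        key=lambda f: int(re.search(f"{key}_([0-9]+)", f).groups()[0]),
        default=None,
    )
    print(key, "->", latest)
# checkpoint -> run/checkpoint_2500.pth.tar
# best_model -> run/best_model_1800.pth.tar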
@ -0,0 +1,75 @@
class TrainerCallback:
    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer

    def on_init_start(self) -> None:
        if hasattr(self.trainer.model, "on_init_start"):
            self.trainer.model.on_init_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_init_start"):
            self.trainer.criterion.on_init_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_init_start"):
            self.trainer.optimizer.on_init_start(self.trainer)

    def on_init_end(self) -> None:
        if hasattr(self.trainer.model, "on_init_end"):
            self.trainer.model.on_init_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_init_end"):
            self.trainer.criterion.on_init_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_init_end"):
            self.trainer.optimizer.on_init_end(self.trainer)

    def on_epoch_start(self) -> None:
        if hasattr(self.trainer.model, "on_epoch_start"):
            self.trainer.model.on_epoch_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_epoch_start"):
            self.trainer.criterion.on_epoch_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_epoch_start"):
            self.trainer.optimizer.on_epoch_start(self.trainer)

    def on_epoch_end(self) -> None:
        if hasattr(self.trainer.model, "on_epoch_end"):
            self.trainer.model.on_epoch_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_epoch_end"):
            self.trainer.criterion.on_epoch_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_epoch_end"):
            self.trainer.optimizer.on_epoch_end(self.trainer)

    def on_train_step_start(self) -> None:
        if hasattr(self.trainer.model, "on_train_step_start"):
            self.trainer.model.on_train_step_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_train_step_start"):
            self.trainer.criterion.on_train_step_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_train_step_start"):
            self.trainer.optimizer.on_train_step_start(self.trainer)

    def on_train_step_end(self) -> None:
        if hasattr(self.trainer.model, "on_train_step_end"):
            self.trainer.model.on_train_step_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_train_step_end"):
            self.trainer.criterion.on_train_step_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_train_step_end"):
            self.trainer.optimizer.on_train_step_end(self.trainer)

    def on_keyboard_interrupt(self) -> None:
        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
            self.trainer.model.on_keyboard_interrupt(self.trainer)

        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
            self.trainer.criterion.on_keyboard_interrupt(self.trainer)

        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
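A hedged sketch of how these hooks fire: any model (or criterion, or optimizer) that defines a matching method gets called with the trainer, and objects without the method are silently skipped. The stub classes are hypothetical:

class StubModel:
    def on_epoch_start(self, trainer):
        print("epoch", trainer.epochs_done, "starting")

class StubTrainer:
    def __init__(self):
        self.model = StubModel()
        self.criterion = object()   # defines no hooks, silently skipped
        self.optimizer = object()
        self.epochs_done = 0

trainer = StubTrainer()
TrainerCallback(trainer).on_epoch_start()  # prints: epoch 0 starting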
@ -1,53 +1,8 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import math

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.utils.data.sampler import Sampler


class DistributedSampler(Sampler):
    """
    Non-shuffling Distributed Sampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        super().__init__(dataset)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make the split evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


def reduce_tensor(tensor, num_gpus):
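To make the subsampling in __iter__ concrete, a quick numeric check of the rank-strided split (values chosen by hand): each rank takes every num_replicas-th index starting at its own rank, after padding the index list up to total_size.

dataset_len, num_replicas = 10, 4
num_samples = -(-dataset_len // num_replicas)       # ceil division -> 3
total_size = num_samples * num_replicas             # 12
indices = list(range(dataset_len))
indices += indices[: total_size - len(indices)]     # pad: [0..9, 0, 1]
for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]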
@ -0,0 +1,65 @@
import importlib
from typing import Dict

import torch

from TTS.utils.training import NoamLR


def is_apex_available():
    return importlib.util.find_spec("apex") is not None


def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.manual_seed(54321)
    use_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count()
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus


def get_scheduler(
    lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
    """Find, initialize and return a scheduler.

    Args:
        lr_scheduler (str): Scheduler name.
        lr_scheduler_params (Dict): Scheduler parameters.
        optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: Functional scheduler.
    """
    if lr_scheduler is None:
        return None
    if lr_scheduler.lower() == "noamlr":
        scheduler = NoamLR
    else:
        scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler)
    return scheduler(optimizer, **lr_scheduler_params)


def get_optimizer(
    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
) -> torch.optim.Optimizer:
    """Find, initialize and return an optimizer.

    Args:
        optimizer_name (str): Optimizer name.
        optimizer_params (dict): Optimizer parameters.
        lr (float): Initial learning rate.
        model (torch.nn.Module): Model to pass to the optimizer.

    Returns:
        torch.optim.Optimizer: Functional optimizer.
    """
    if optimizer_name.lower() == "radam":
        module = importlib.import_module("TTS.utils.radam")
        optimizer = getattr(module, "RAdam")
    else:
        optimizer = getattr(torch.optim, optimizer_name)
    return optimizer(model.parameters(), lr=lr, **optimizer_params)
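A hedged usage sketch of the two factories above; the parameter dicts are illustrative, not taken from any shipped config. Both names are resolved with getattr on torch.optim and torch.optim.lr_scheduler respectively, so any stock PyTorch optimizer or scheduler name should work unmodified.

import torch

model = torch.nn.Linear(10, 2)
optimizer = get_optimizer("AdamW", {"betas": (0.9, 0.998), "weight_decay": 1e-6}, 1e-3, model)
scheduler = get_scheduler("StepLR", {"step_size": 1000, "gamma": 0.5}, optimizer)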
@ -2,17 +2,6 @@ import numpy as np
import torch


def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.manual_seed(54321)
    use_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count()
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus


def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Check model gradient against unexpected jumps and failures"""
    skip_flag = False
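For context, a hedged sketch of how check_update is typically used in the training loops above, assuming it clips over model.parameters() and raises the skip flag on inf/NaN gradients (the grad_clip value is arbitrary):

import torch
from TTS.utils.training import check_update

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
grad_norm, skip_flag = check_update(model, grad_clip=5.0)
if not skip_flag:
    optimizer.step()  # skip the update when gradients are broken
optimizer.zero_grad()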
@ -41,46 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    return grad_norm, skip_flag


def lr_decay(init_lr, global_step, warmup_steps):
    r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py"""
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
    return lr


def adam_weight_decay(optimizer):
    """
    Custom weight decay operation, not affecting grad values.
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            current_lr = group["lr"]
            weight_decay = group["weight_decay"]
            factor = -weight_decay * group["lr"]
            param.data = param.data.add(param.data, alpha=factor)
    return optimizer, current_lr


# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
    """
    Skip biases, BatchNorm parameters, rnns,
    and the attention projection layer v.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]


# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
@ -107,3 +56,31 @@ def gradual_training_scheduler(global_step, config):
        if global_step * num_gpus >= values[0]:
            new_values = values
    return new_values[1], new_values[2]


def lr_decay(init_lr, global_step, warmup_steps):
    r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py
    It is only being used by the Speaker Encoder trainer."""
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
    return lr


# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
    """
    Skip biases, BatchNorm parameters, rnns,
    and the attention projection layer v.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]
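And a hedged example of feeding set_weight_decay's param groups to an optimizer, so skipped parameters end up in the weight_decay=0.0 group; the Tiny module is hypothetical and chosen so that its parameter names hit the skip_list:

import torch
from torch import nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 8)   # names contain "embedding" -> no decay
        self.rnn = nn.GRU(8, 8)                # names contain "rnn" -> no decay
        self.proj = nn.Linear(8, 4)            # weight decays, bias (1-D) does not

groups = set_weight_decay(Tiny(), weight_decay=1e-6)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
print([len(g["params"]) for g in groups])  # [no_decay count, decay count]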