From c7aad884cdd5b99f620390e0c3af58cdbd710418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 18 Jun 2021 13:22:05 +0200 Subject: [PATCH] Implement unified trainer --- TTS/bin/train_encoder.py | 2 +- TTS/bin/train_tts.py | 24 +- TTS/bin/train_vocoder.py | 27 + TTS/bin/train_vocoder_gan.py | 638 ------------------ TTS/bin/train_vocoder_wavegrad.py | 431 ------------ TTS/bin/train_vocoder_wavernn.py | 431 ------------ TTS/trainer.py | 999 ++++++++++++++++++++++++++-- TTS/tts/models/tacotron_abstract.py | 245 ------- TTS/tts/trainer_tts.py | 709 -------------------- TTS/utils/arguments.py | 182 ----- TTS/utils/callbacks.py | 75 +++ TTS/utils/distribute.py | 45 -- TTS/utils/trainer_utils.py | 65 ++ TTS/utils/training.py | 79 +-- 14 files changed, 1128 insertions(+), 2824 deletions(-) create mode 100644 TTS/bin/train_vocoder.py delete mode 100755 TTS/bin/train_vocoder_gan.py delete mode 100644 TTS/bin/train_vocoder_wavegrad.py delete mode 100644 TTS/bin/train_vocoder_wavernn.py delete mode 100644 TTS/tts/models/tacotron_abstract.py delete mode 100644 TTS/tts/trainer_tts.py delete mode 100644 TTS/utils/arguments.py create mode 100644 TTS/utils/callbacks.py create mode 100644 TTS/utils/trainer_utils.py diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 6e4a9b32..38902a18 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -13,8 +13,8 @@ from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model from TTS.speaker_encoder.utils.visual import plot_embeddings +from TTS.trainer import init_training from TTS.tts.datasets import load_meta_data -from TTS.utils.arguments import init_training from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam 
import RAdam diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 06765906..c491700d 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,27 +1,13 @@ -import os import sys -import traceback -from TTS.tts.trainer_tts import TrainerTTS -from TTS.utils.arguments import init_training -from TTS.utils.generic_utils import remove_experiment_folder +from TTS.trainer import Trainer, init_training def main(): - try: - args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv) - trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=output_path) - trainer.fit() - except KeyboardInterrupt: - remove_experiment_folder(output_path) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(output_path) - traceback.print_exc() - sys.exit(1) + """Run 🐸TTS trainer from terminal. This is also necessary to run DDP training by ```distribute.py```""" + args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv) + trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False) + trainer.fit() if __name__ == "__main__": diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py new file mode 100644 index 00000000..868aae2e --- /dev/null +++ b/TTS/bin/train_vocoder.py @@ -0,0 +1,27 @@ +import os +import sys +import traceback + +from TTS.trainer import Trainer, init_training +from TTS.utils.generic_utils import remove_experiment_folder + + +def main(): + try: + args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv) + trainer = Trainer(args, config, output_path, c_logger, tb_logger) + trainer.fit() + except KeyboardInterrupt: + remove_experiment_folder(output_path) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(output_path) + 
traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py deleted file mode 100755 index ea317ef6..00000000 --- a/TTS/bin/train_vocoder_gan.py +++ /dev/null @@ -1,638 +0,0 @@ -#!/usr/bin/env python3 -# TODO: mixed precision training -"""Trains GAN based vocoder model.""" - -import itertools -import os -import sys -import time -import traceback -from inspect import signature - -import torch - -# DISTRIBUTED -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.training import setup_torch_training_env -from TTS.vocoder.datasets.gan_dataset import GANDataset -from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data -from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss -from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator -from TTS.vocoder.utils.io import save_best_model, save_checkpoint - -use_cuda, num_gpus = setup_torch_training_env(True, True) - - -def setup_loader(ap, is_val=False, verbose=False): - loader = None - if not is_val or c.run_eval: - dataset = GANDataset( - ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - return_pairs=c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False, - is_training=not is_val, - return_segments=not is_val, - use_noise_augment=c.use_noise_augment, - use_cache=c.use_cache, - verbose=verbose, - ) - dataset.shuffle_mapping() - sampler = DistributedSampler(dataset, 
shuffle=True) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=1 if is_val else c.batch_size, - shuffle=num_gpus == 0, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, - pin_memory=False, - ) - return loader - - -def format_data(data): - if isinstance(data[0], list): - x_G, y_G = data[0] - x_D, y_D = data[1] - if use_cuda: - x_G = x_G.cuda(non_blocking=True) - y_G = y_G.cuda(non_blocking=True) - x_D = x_D.cuda(non_blocking=True) - y_D = y_D.cuda(non_blocking=True) - return x_G, y_G, x_D, y_D - x, y = data - if use_cuda: - x = x.cuda(non_blocking=True) - y = y.cuda(non_blocking=True) - return x, y, None, None - - -def train( - model_G, - criterion_G, - optimizer_G, - model_D, - criterion_D, - optimizer_D, - scheduler_G, - scheduler_D, - ap, - global_step, - epoch, -): - data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) - model_G.train() - model_D.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / c.batch_size) - end_time = time.time() - c_logger.print_train_start() - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - c_G, y_G, c_D, y_D = format_data(data) - loader_time = time.time() - end_time - - global_step += 1 - - ############################## - # GENERATOR - ############################## - - # generator pass - y_hat = model_G(c_G) - y_hat_sub = None - y_G_sub = None - y_hat_vis = y_hat # for visualization - - # PQMF formatting - if y_hat.shape[1] > 1: - y_hat_sub = y_hat - y_hat = model_G.pqmf_synthesis(y_hat) - y_hat_vis = y_hat - y_G_sub = model_G.pqmf_analysis(y_G) - - scores_fake, feats_fake, feats_real = None, None, None - if global_step > c.steps_to_start_discriminator: - - # run D with or without cond. 
features - if len(signature(model_D.forward).parameters) == 2: - D_out_fake = model_D(y_hat, c_G) - else: - D_out_fake = model_D(y_hat) - D_out_real = None - - if c.use_feat_match_loss: - with torch.no_grad(): - D_out_real = model_D(y_G) - - # format D outputs - if isinstance(D_out_fake, tuple): - scores_fake, feats_fake = D_out_fake - if D_out_real is None: - feats_real = None - else: - # we don't need scores for real samples for training G since they are always 1 - _, feats_real = D_out_real - else: - scores_fake = D_out_fake - - # compute losses - loss_G_dict = criterion_G( - y_hat=y_hat, - y=y_G, - scores_fake=scores_fake, - feats_fake=feats_fake, - feats_real=feats_real, - y_hat_sub=y_hat_sub, - y_sub=y_G_sub, - ) - loss_G = loss_G_dict["G_loss"] - - # optimizer generator - optimizer_G.zero_grad() - loss_G.backward() - if c.gen_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad) - optimizer_G.step() - - loss_dict = dict() - for key, value in loss_G_dict.items(): - if isinstance(value, int): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - ############################## - # DISCRIMINATOR - ############################## - if global_step >= c.steps_to_start_discriminator: - # discriminator pass - if c.diff_samples_for_G_and_D: - # use a different sample than generator - with torch.no_grad(): - y_hat = model_G(c_D) - - # PQMF formatting - if y_hat.shape[1] > 1: - y_hat = model_G.pqmf_synthesis(y_hat) - else: - # use the same samples as generator - c_D = c_G.clone() - y_D = y_G.clone() - - # run D with or without cond. 
features - if len(signature(model_D.forward).parameters) == 2: - D_out_fake = model_D(y_hat.detach().clone(), c_D) - D_out_real = model_D(y_D, c_D) - else: - D_out_fake = model_D(y_hat.detach()) - D_out_real = model_D(y_D) - - # format D outputs - if isinstance(D_out_fake, tuple): - # model_D returns scores and features - scores_fake, feats_fake = D_out_fake - if D_out_real is None: - scores_real, feats_real = None, None - else: - scores_real, feats_real = D_out_real - else: - # model D returns only scores - scores_fake = D_out_fake - scores_real = D_out_real - - # compute losses - loss_D_dict = criterion_D(scores_fake, scores_real) - loss_D = loss_D_dict["D_loss"] - - # optimizer discriminator - optimizer_D.zero_grad() - loss_D.backward() - if c.disc_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad) - optimizer_D.step() - - for key, value in loss_D_dict.items(): - if isinstance(value, (int, float)): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - step_time = time.time() - start_time - epoch_time += step_time - - # get current learning rates - current_lr_G = list(optimizer_G.param_groups)[0]["lr"] - current_lr_D = list(optimizer_D.param_groups)[0]["lr"] - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # print training stats - if global_step % c.print_step == 0: - log_dict = { - "step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr_G": current_lr_G, - "current_lr_D": current_lr_D, - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # plot step stats - if global_step % 10 == 0: - iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time} - 
iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - # save checkpoint - if global_step % c.save_step == 0: - if c.checkpoint: - # save model - save_checkpoint( - model_G, - optimizer_G, - scheduler_G, - model_D, - optimizer_D, - scheduler_D, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - ) - - # compute spectrograms - figures = plot_results(y_hat_vis, y_G, ap, global_step, "train") - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"]) - end_time = time.time() - - if scheduler_G is not None: - scheduler_G.step() - - if scheduler_D is not None: - scheduler_D.step() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Training Epoch Stats - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - if args.rank == 0: - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - # TODO: plot model stats - # if c.tb_model_param_stats: - # tb_logger.tb_model_weights(model, global_step) - torch.cuda.empty_cache() - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch): - data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0)) - model_G.eval() - model_D.eval() - epoch_time = 0 - keep_avg = KeepAverage() - end_time = time.time() - c_logger.print_eval_start() - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - c_G, y_G, _, _ = format_data(data) - loader_time = time.time() - end_time - - global_step += 1 - - ############################## - # GENERATOR - ############################## - - # generator pass - y_hat = model_G(c_G)[:, :, : y_G.size(2)] - y_hat_sub = None - y_G_sub = None - - # PQMF formatting - if y_hat.shape[1] > 
1: - y_hat_sub = y_hat - y_hat = model_G.pqmf_synthesis(y_hat) - y_G_sub = model_G.pqmf_analysis(y_G) - - scores_fake, feats_fake, feats_real = None, None, None - if global_step > c.steps_to_start_discriminator: - - if len(signature(model_D.forward).parameters) == 2: - D_out_fake = model_D(y_hat, c_G) - else: - D_out_fake = model_D(y_hat) - D_out_real = None - - if c.use_feat_match_loss: - with torch.no_grad(): - D_out_real = model_D(y_G) - - # format D outputs - if isinstance(D_out_fake, tuple): - scores_fake, feats_fake = D_out_fake - if D_out_real is None: - feats_real = None - else: - _, feats_real = D_out_real - else: - scores_fake = D_out_fake - feats_fake, feats_real = None, None - - # compute losses - loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub) - - loss_dict = dict() - for key, value in loss_G_dict.items(): - if isinstance(value, (int, float)): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - ############################## - # DISCRIMINATOR - ############################## - - if global_step >= c.steps_to_start_discriminator: - # discriminator pass - with torch.no_grad(): - y_hat = model_G(c_G)[:, :, : y_G.size(2)] - - # PQMF formatting - if y_hat.shape[1] > 1: - y_hat = model_G.pqmf_synthesis(y_hat) - - # run D with or without cond. 
features - if len(signature(model_D.forward).parameters) == 2: - D_out_fake = model_D(y_hat.detach(), c_G) - D_out_real = model_D(y_G, c_G) - else: - D_out_fake = model_D(y_hat.detach()) - D_out_real = model_D(y_G) - - # format D outputs - if isinstance(D_out_fake, tuple): - scores_fake, feats_fake = D_out_fake - if D_out_real is None: - scores_real, feats_real = None, None - else: - scores_real, feats_real = D_out_real - else: - scores_fake = D_out_fake - scores_real = D_out_real - - # compute losses - loss_D_dict = criterion_D(scores_fake, scores_real) - - for key, value in loss_D_dict.items(): - if isinstance(value, (int, float)): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - step_time = time.time() - start_time - epoch_time += step_time - - # update avg stats - update_eval_values = dict() - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - update_eval_values["avg_loader_time"] = loader_time - update_eval_values["avg_step_time"] = step_time - keep_avg.update_values(update_eval_values) - - # print eval stats - if c.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # compute spectrograms - figures = plot_results(y_hat, y_G, ap, global_step, "eval") - tb_logger.tb_eval_figures(global_step, figures) - - # Sample audio - predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy() - real_waveform = y_G[0].squeeze(0).cpu().numpy() - tb_logger.tb_eval_audios( - global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"] - ) - - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - - # synthesize a full voice - data_loader.return_segments = False - torch.cuda.empty_cache() - return keep_avg.avg_values - - -def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global train_data, eval_data - print(f" > Loading wavs from: {c.data_path}") - if c.feature_path is 
not None: - print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) - else: - eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) - - # setup audio processor - ap = AudioProcessor(**c.audio.to_dict()) - - # DISTRUBUTED - if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) - - # setup models - model_gen = setup_generator(c) - model_disc = setup_discriminator(c) - - # setup criterion - criterion_gen = GeneratorLoss(c) - criterion_disc = DiscriminatorLoss(c) - - if use_cuda: - model_gen.cuda() - criterion_gen.cuda() - model_disc.cuda() - criterion_disc.cuda() - - # setup optimizers - # TODO: allow loading custom optimizers - optimizer_gen = None - optimizer_disc = None - optimizer_gen = getattr(torch.optim, c.optimizer) - optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params) - optimizer_disc = getattr(torch.optim, c.optimizer) - - if c.discriminator_model == "hifigan_discriminator": - optimizer_disc = optimizer_disc( - itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), - lr=c.lr_disc, - **c.optimizer_params, - ) - else: - optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params) - - # schedulers - scheduler_gen = None - scheduler_disc = None - if "lr_scheduler_gen" in c: - scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) - scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) - if "lr_scheduler_disc" in c: - scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) - scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) - - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - print(" > Restoring Generator 
Model...") - model_gen.load_state_dict(checkpoint["model"]) - print(" > Restoring Generator Optimizer...") - optimizer_gen.load_state_dict(checkpoint["optimizer"]) - print(" > Restoring Discriminator Model...") - model_disc.load_state_dict(checkpoint["model_disc"]) - print(" > Restoring Discriminator Optimizer...") - optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) - # restore schedulers if it is a continuing training. - if args.continue_path != "": - if "scheduler" in checkpoint and scheduler_gen is not None: - print(" > Restoring Generator LR Scheduler...") - scheduler_gen.load_state_dict(checkpoint["scheduler"]) - # NOTE: Not sure if necessary - scheduler_gen.optimizer = optimizer_gen - if "scheduler_disc" in checkpoint and scheduler_disc is not None: - print(" > Restoring Discriminator LR Scheduler...") - scheduler_disc.load_state_dict(checkpoint["scheduler_disc"]) - scheduler_disc.optimizer = optimizer_disc - if c.lr_scheduler_disc == "ExponentialLR": - scheduler_disc.last_epoch = checkpoint["epoch"] - except RuntimeError: - # restore only matching layers. - print(" > Partial model initialization...") - model_dict = model_gen.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], c) - model_gen.load_state_dict(model_dict) - - model_dict = model_disc.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c) - model_disc.load_state_dict(model_dict) - del model_dict - - # reset lr if not countinuining training. 
- if args.continue_path == "": - for group in optimizer_gen.param_groups: - group["lr"] = c.lr_gen - - for group in optimizer_disc.param_groups: - group["lr"] = c.lr_disc - - print(f" > Model restored from step {checkpoint['step']:d}", flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - # DISTRUBUTED - if num_gpus > 1: - model_gen = DDP_th(model_gen, device_ids=[args.rank]) - model_disc = DDP_th(model_disc, device_ids=[args.rank]) - - num_params = count_parameters(model_gen) - print(" > Generator has {} parameters".format(num_params), flush=True) - num_params = count_parameters(model_disc) - print(" > Discriminator has {} parameters".format(num_params), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with best loss of {best_loss}.") - keep_all_best = c.get("keep_all_best", False) - keep_after = c.get("keep_after", 10000) # void if keep_all_best False - - global_step = args.restore_step - for epoch in range(0, c.epochs): - c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train( - model_gen, - criterion_gen, - optimizer_gen, - model_disc, - criterion_disc, - optimizer_disc, - scheduler_gen, - scheduler_disc, - ap, - global_step, - epoch, - ) - eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = eval_avg_loss_dict[c.target_loss] - best_loss = save_best_model( - target_loss, - best_loss, - model_gen, - optimizer_gen, - scheduler_gen, - model_disc, - optimizer_disc, - scheduler_disc, - global_step, - epoch, - OUT_PATH, - keep_all_best=keep_all_best, - keep_after=keep_after, - model_losses=eval_avg_loss_dict, - ) - - -if 
__name__ == "__main__": - args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py deleted file mode 100644 index c8f067ee..00000000 --- a/TTS/bin/train_vocoder_wavegrad.py +++ /dev/null @@ -1,431 +0,0 @@ -#!/usr/bin/env python3 -"""Trains WaveGrad vocoder models.""" - -import os -import sys -import time -import traceback - -import numpy as np -import torch - -# DISTRIBUTED -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.optim import Adam -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.training import setup_torch_training_env -from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data -from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset -from TTS.vocoder.utils.generic_utils import plot_results, setup_generator -from TTS.vocoder.utils.io import save_best_model, save_checkpoint - -use_cuda, num_gpus = setup_torch_training_env(True, True) - - -def setup_loader(ap, is_val=False, verbose=False): - if is_val and not c.run_eval: - loader = None - else: - dataset = WaveGradDataset( - ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - is_training=not is_val, - return_segments=True, - 
use_noise_augment=False, - use_cache=c.use_cache, - verbose=verbose, - ) - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=c.batch_size, - shuffle=num_gpus <= 1, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, - pin_memory=False, - ) - - return loader - - -def format_data(data): - # return a whole audio segment - m, x = data - x = x.unsqueeze(1) - if use_cuda: - m = m.cuda(non_blocking=True) - x = x.cuda(non_blocking=True) - return m, x - - -def format_test_data(data): - # return a whole audio segment - m, x = data - m = m[None, ...] - x = x[None, None, ...] - if use_cuda: - m = m.cuda(non_blocking=True) - x = x.cuda(non_blocking=True) - return m, x - - -def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch): - data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / c.batch_size) - end_time = time.time() - c_logger.print_train_start() - # setup noise schedule - noise_schedule = c["train_noise_schedule"] - betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) - if hasattr(model, "module"): - model.module.compute_noise_level(betas) - else: - model.compute_noise_level(betas) - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - m, x = format_data(data) - loader_time = time.time() - end_time - - global_step += 1 - - with torch.cuda.amp.autocast(enabled=c.mixed_precision): - # compute noisy input - if hasattr(model, "module"): - noise, x_noisy, noise_scale = model.module.compute_y_n(x) - else: - noise, x_noisy, noise_scale = model.compute_y_n(x) - - # forward pass - noise_hat = model(x_noisy, m, noise_scale) - - # 
compute losses - loss = criterion(noise, noise_hat) - loss_wavegrad_dict = {"wavegrad_loss": loss} - - # check nan loss - if torch.isnan(loss).any(): - raise RuntimeError(f"Detected NaN loss at step {global_step}.") - - optimizer.zero_grad() - - # backward pass with loss scaling - if c.mixed_precision: - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - loss.backward() - grad_norm = torch.nn.utils.grad_clip_norm_(model.parameters(), c.clip_grad) - optimizer.step() - - # schedule update - if scheduler is not None: - scheduler.step() - - # disconnect loss values - loss_dict = dict() - for key, value in loss_wavegrad_dict.items(): - if isinstance(value, int): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - # epoch/step timing - step_time = time.time() - start_time - epoch_time += step_time - - # get current learning rates - current_lr = list(optimizer.param_groups)[0]["lr"] - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # print training stats - if global_step % c.print_step == 0: - log_dict = { - "step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr": current_lr, - "grad_norm": grad_norm.item(), - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # plot step stats - if global_step % 10 == 0: - iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time} - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - # save checkpoint - if global_step % c.save_step == 0: - if c.checkpoint: - # save model - 
save_checkpoint( - model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None, - ) - - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Training Epoch Stats - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - if args.rank == 0: - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - # TODO: plot model stats - if c.tb_model_param_stats and args.rank == 0: - tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0)) - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - end_time = time.time() - c_logger.print_eval_start() - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - m, x = format_data(data) - loader_time = time.time() - end_time - - global_step += 1 - - # compute noisy input - if hasattr(model, "module"): - noise, x_noisy, noise_scale = model.module.compute_y_n(x) - else: - noise, x_noisy, noise_scale = model.compute_y_n(x) - - # forward pass - noise_hat = model(x_noisy, m, noise_scale) - - # compute losses - loss = criterion(noise, noise_hat) - loss_wavegrad_dict = {"wavegrad_loss": loss} - - loss_dict = dict() - for key, value in loss_wavegrad_dict.items(): - if isinstance(value, (int, float)): - loss_dict[key] = value - else: - loss_dict[key] = value.item() - - step_time = time.time() - start_time - epoch_time += step_time - - # update avg stats - update_eval_values = dict() - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - update_eval_values["avg_loader_time"] = loader_time - update_eval_values["avg_step_time"] = step_time - 
keep_avg.update_values(update_eval_values) - - # print eval stats - if c.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - data_loader.dataset.return_segments = False - samples = data_loader.dataset.load_test_samples(1) - m, x = format_test_data(samples[0]) - - # setup noise schedule and inference - noise_schedule = c["test_noise_schedule"] - betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) - if hasattr(model, "module"): - model.module.compute_noise_level(betas) - # compute voice - x_pred = model.module.inference(m) - else: - model.compute_noise_level(betas) - # compute voice - x_pred = model.inference(m) - - # compute spectrograms - figures = plot_results(x_pred, x, ap, global_step, "eval") - tb_logger.tb_eval_figures(global_step, figures) - - # Sample audio - sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]) - - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - data_loader.dataset.return_segments = True - - return keep_avg.avg_values - - -def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global train_data, eval_data - print(f" > Loading wavs from: {c.data_path}") - if c.feature_path is not None: - print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) - else: - eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) - - # setup audio processor - ap = AudioProcessor(**c.audio.to_dict()) - - # DISTRUBUTED - if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) - - # setup models - model = setup_generator(c) - - # scaler for mixed_precision - scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None - - # setup 
optimizers - optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0) - - # schedulers - scheduler = None - if "lr_scheduler" in c: - scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler) - scheduler = scheduler(optimizer, **c.lr_scheduler_params) - - # setup criterion - criterion = torch.nn.L1Loss().cuda() - - if use_cuda: - model.cuda() - criterion.cuda() - - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - print(" > Restoring Model...") - model.load_state_dict(checkpoint["model"]) - print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint["optimizer"]) - if "scheduler" in checkpoint: - print(" > Restoring LR Scheduler...") - scheduler.load_state_dict(checkpoint["scheduler"]) - # NOTE: Not sure if necessary - scheduler.optimizer = optimizer - if "scaler" in checkpoint and c.mixed_precision: - print(" > Restoring AMP Scaler...") - scaler.load_state_dict(checkpoint["scaler"]) - except RuntimeError: - # retore only matching layers. - print(" > Partial model initialization...") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], c) - model.load_state_dict(model_dict) - del model_dict - - # reset lr if not countinuining training. 
- for group in optimizer.param_groups: - group["lr"] = c.lr - - print(" > Model restored from step %d" % checkpoint["step"], flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - # DISTRUBUTED - if num_gpus > 1: - model = DDP_th(model, device_ids=[args.rank]) - - num_params = count_parameters(model) - print(" > WaveGrad has {} parameters".format(num_params), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get("keep_all_best", False) - keep_after = c.get("keep_after", 10000) # void if keep_all_best False - - global_step = args.restore_step - for epoch in range(0, c.epochs): - c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch) - eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = eval_avg_loss_dict[c.target_loss] - best_loss = save_best_model( - target_loss, - best_loss, - model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - keep_all_best=keep_all_best, - keep_after=keep_after, - model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None, - ) - - -if __name__ == "__main__": - args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff 
--git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py deleted file mode 100644 index 86a1506a..00000000 --- a/TTS/bin/train_vocoder_wavernn.py +++ /dev/null @@ -1,431 +0,0 @@ -#!/usr/bin/env python3 -"""Train WaveRNN vocoder model.""" - -import os -import random -import sys -import time -import traceback - -import torch -from torch.utils.data import DataLoader - -from TTS.tts.utils.visual import plot_spectrogram -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.radam import RAdam -from TTS.utils.training import setup_torch_training_env -from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data -from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss -from TTS.vocoder.utils.generic_utils import setup_generator -from TTS.vocoder.utils.io import save_best_model, save_checkpoint - -# from torch.utils.data.distributed import DistributedSampler - - -use_cuda, num_gpus = setup_torch_training_env(True, True) - - -def setup_loader(ap, is_val=False, verbose=False): - if is_val and not c.run_eval: - loader = None - else: - dataset = WaveRNNDataset( - ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad=c.padding, - mode=c.mode, - mulaw=c.mulaw, - is_training=not is_val, - verbose=verbose, - ) - # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - shuffle=True, - collate_fn=dataset.collate, - batch_size=c.batch_size, - num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, - pin_memory=True, - ) - return loader - - -def format_data(data): - # setup input data - x_input = data[0] - mels = data[1] - y_coarse = data[2] - - # dispatch data to GPU - if use_cuda: - 
x_input = x_input.cuda(non_blocking=True) - mels = mels.cuda(non_blocking=True) - y_coarse = y_coarse.cuda(non_blocking=True) - - return x_input, mels, y_coarse - - -def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch): - # create train loader - data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / c.batch_size) - end_time = time.time() - c_logger.print_train_start() - # train loop - for num_iter, data in enumerate(data_loader): - start_time = time.time() - x_input, mels, y_coarse = format_data(data) - loader_time = time.time() - end_time - global_step += 1 - - optimizer.zero_grad() - - if c.mixed_precision: - # mixed precision training - with torch.cuda.amp.autocast(): - y_hat = model(x_input, mels) - if isinstance(model.mode, int): - y_hat = y_hat.transpose(1, 2).unsqueeze(-1) - else: - y_coarse = y_coarse.float() - y_coarse = y_coarse.unsqueeze(-1) - # compute losses - loss = criterion(y_hat, y_coarse) - scaler.scale(loss).backward() - scaler.unscale_(optimizer) - if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - # full precision training - y_hat = model(x_input, mels) - if isinstance(model.mode, int): - y_hat = y_hat.transpose(1, 2).unsqueeze(-1) - else: - y_coarse = y_coarse.float() - y_coarse = y_coarse.unsqueeze(-1) - # compute losses - loss = criterion(y_hat, y_coarse) - if loss.item() is None: - raise RuntimeError(" [!] None loss. 
Exiting ...") - loss.backward() - if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) - optimizer.step() - - if scheduler is not None: - scheduler.step() - - # get the current learning rate - cur_lr = list(optimizer.param_groups)[0]["lr"] - - step_time = time.time() - start_time - epoch_time += step_time - - update_train_values = dict() - loss_dict = dict() - loss_dict["model_loss"] = loss.item() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # print training stats - if global_step % c.print_step == 0: - log_dict = { - "step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr": cur_lr, - } - c_logger.print_train_step( - batch_n_iter, - num_iter, - global_step, - log_dict, - loss_dict, - keep_avg.avg_values, - ) - - # plot step stats - if global_step % 10 == 0: - iter_stats = {"lr": cur_lr, "step_time": step_time} - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - # save checkpoint - if global_step % c.save_step == 0: - if c.checkpoint: - # save model - save_checkpoint( - model, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None, - ) - - # synthesize a full voice - rand_idx = random.randrange(0, len(train_data)) - wav_path = ( - train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] - ) - wav = ap.load_wav(wav_path) - ground_mel = ap.melspectrogram(wav) - ground_mel = torch.FloatTensor(ground_mel) - if use_cuda: - ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference( - ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) - predict_mel = ap.melspectrogram(sample_wav) - - # 
compute spectrograms - figures = { - "train/ground_truth": plot_spectrogram(ground_mel.T), - "train/prediction": plot_spectrogram(predict_mel.T), - } - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]) - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Training Epoch Stats - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - # TODO: plot model stats - # if c.tb_model_param_stats: - # tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(model, criterion, ap, global_step, epoch): - # create train loader - data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0)) - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - end_time = time.time() - c_logger.print_eval_start() - with torch.no_grad(): - for num_iter, data in enumerate(data_loader): - start_time = time.time() - # format data - x_input, mels, y_coarse = format_data(data) - loader_time = time.time() - end_time - global_step += 1 - - y_hat = model(x_input, mels) - if isinstance(model.mode, int): - y_hat = y_hat.transpose(1, 2).unsqueeze(-1) - else: - y_coarse = y_coarse.float() - y_coarse = y_coarse.unsqueeze(-1) - loss = criterion(y_hat, y_coarse) - # Compute avg loss - # if num_gpus > 1: - # loss = reduce_tensor(loss.data, num_gpus) - loss_dict = dict() - loss_dict["model_loss"] = loss.item() - - step_time = time.time() - start_time - epoch_time += step_time - - # update avg stats - update_eval_values = dict() - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - update_eval_values["avg_loader_time"] = loader_time - update_eval_values["avg_step_time"] = step_time - keep_avg.update_values(update_eval_values) - - # 
print eval stats - if c.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if epoch % c.test_every_epochs == 0 and epoch != 0: - # synthesize a full voice - rand_idx = random.randrange(0, len(eval_data)) - wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] - wav = ap.load_wav(wav_path) - ground_mel = ap.melspectrogram(wav) - ground_mel = torch.FloatTensor(ground_mel) - if use_cuda: - ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference( - ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) - predict_mel = ap.melspectrogram(sample_wav) - - # Sample audio - tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"]) - - # compute spectrograms - figures = { - "eval/ground_truth": plot_spectrogram(ground_mel.T), - "eval/prediction": plot_spectrogram(predict_mel.T), - } - tb_logger.tb_eval_figures(global_step, figures) - - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - return keep_avg.avg_values - - -# FIXME: move args definition/parsing inside of main? 
-def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global train_data, eval_data - - # setup audio processor - ap = AudioProcessor(**c.audio.to_dict()) - - print(f" > Loading wavs from: {c.data_path}") - if c.feature_path is not None: - print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) - else: - eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) - # setup model - model_wavernn = setup_generator(c) - - # setup amp scaler - scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None - - # define train functions - if c.mode == "mold": - criterion = discretized_mix_logistic_loss - elif c.mode == "gauss": - criterion = gaussian_loss - elif isinstance(c.mode, int): - criterion = torch.nn.CrossEntropyLoss() - - if use_cuda: - model_wavernn.cuda() - if isinstance(c.mode, int): - criterion.cuda() - - optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0) - - scheduler = None - if "lr_scheduler" in c: - scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler) - scheduler = scheduler(optimizer, **c.lr_scheduler_params) - # slow start for the first 5 epochs - # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1) - # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) - - # restore any checkpoint - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - print(" > Restoring Model...") - model_wavernn.load_state_dict(checkpoint["model"]) - print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint["optimizer"]) - if "scheduler" in checkpoint: - print(" > Restoring Generator LR Scheduler...") - scheduler.load_state_dict(checkpoint["scheduler"]) - scheduler.optimizer = optimizer - if "scaler" in checkpoint and c.mixed_precision: - print(" > Restoring 
AMP Scaler...") - scaler.load_state_dict(checkpoint["scaler"]) - except RuntimeError: - # retore only matching layers. - print(" > Partial model initialization...") - model_dict = model_wavernn.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], c) - model_wavernn.load_state_dict(model_dict) - - print(" > Model restored from step %d" % checkpoint["step"], flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - # DISTRIBUTED - # if num_gpus > 1: - # model = apply_gradient_allreduce(model) - - num_parameters = count_parameters(model_wavernn) - print(" > Model has {} parameters".format(num_parameters), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get("keep_all_best", False) - keep_after = c.get("keep_after", 10000) # void if keep_all_best False - - global_step = args.restore_step - for epoch in range(0, c.epochs): - c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch) - eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = eval_avg_loss_dict["avg_model_loss"] - best_loss = save_best_model( - target_loss, - best_loss, - model_wavernn, - optimizer, - scheduler, - None, - None, - None, - global_step, - epoch, - OUT_PATH, - keep_all_best=keep_all_best, - keep_after=keep_after, - model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None, - ) - - -if __name__ == "__main__": - args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - try: - 
main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff --git a/TTS/trainer.py b/TTS/trainer.py index 5c02fdfb..8b7be3d1 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,22 +1,52 @@ # -*- coding: utf-8 -*- +import glob import importlib -from abc import ABC, abstractmethod +import logging +import os +import re +import sys +import time +import traceback +from argparse import Namespace from dataclasses import dataclass, field -from typing import Dict, List, Tuple, TypeVar +from typing import Dict, List, Tuple, Union import torch from coqpit import Coqpit - -# DISTRIBUTED from torch import nn +from torch.nn.parallel import DistributedDataParallel as DDP_th +from torch.utils.data import DataLoader -_DataLoader = TypeVar("_DataLoader") +from TTS.config import load_config +from TTS.tts.datasets import load_meta_data +from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.utils.text.symbols import parse_symbols +from TTS.utils.audio import AudioProcessor +from TTS.utils.callbacks import TrainerCallback +from TTS.utils.distribute import init_distributed +from TTS.utils.generic_utils import ( + KeepAverage, + count_parameters, + create_experiment_folder, + get_git_branch, + remove_experiment_folder, + set_init_dict, + to_cuda, +) +from TTS.utils.io import copy_model_files, save_best_model, save_checkpoint +from TTS.utils.logging import ConsoleLogger, TensorboardLogger +from TTS.utils.trainer_utils import * +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model as setup_vocoder_model + +if is_apex_available(): + from apex import amp @dataclass class TrainingArgs(Coqpit): - """Trainer arguments that are parsed externally (e.g. 
CLI)""" + """Trainer arguments""" continue_path: str = field( default="", @@ -41,101 +71,926 @@ class TrainingArgs(Coqpit): group_id: str = field(default="", metadata={"help": "Process group id in distributed training."}) -# pylint: disable=import-outside-toplevel, too-many-public-methods +class Trainer: + def __init__( + self, + args: Union[Coqpit, Namespace], + config: Coqpit, + output_path: str, + c_logger: ConsoleLogger = None, + tb_logger: TensorboardLogger = None, + model: nn.Module = None, + cudnn_benchmark: bool = False, + ) -> None: + """Simple yet powerful 🐸💬 TTS trainer for PyTorch. It can train all the available `tts` and `vocoder` models + or easily be customized. + Notes: -class TrainerAbstract(ABC): + Supports Automatic Mixed Precision training. If `Apex` is availabe, it automatically picks that, else + it uses PyTorch's native `amp` module. `Apex` may provide more stable training in some cases. + + Args: + + args (Union[Coqpit, Namespace]): Training arguments parsed either from console by `argparse` or `TrainingArgs` + config object. + + config (Coqpit): Model config object. It includes all the values necessary for initializing, training, evaluating + and testing the model. + + output_path (str): Path to the output training folder. All the files are saved under thi path. + + c_logger (ConsoleLogger, optional): Console logger for printing training status. If not provided, the default + console logger is used. Defaults to None. + + tb_logger (TensorboardLogger, optional): Tensorboard logger. If not provided, the default logger is used. + Defaults to None. + + model (nn.Module, optional): Initialized and ready-to-train model. If it is not defined, `Trainer` + initializes a model from the provided config. Defaults to None. + + cudnn_benchmark (bool): enable/disable PyTorch cudnn benchmarking. It is better to disable if the model input + length is changing batch to batch along the training. + + Examples: + + Running trainer on a model. 
+ + >>> args = TrainingArgs(...) + >>> config = HifiganConfig(...) + >>> model = GANModel(config) + >>> trainer = Trainer(args, config, output_path, model=model) + >>> trainer.fit() + + Running trainer on a config. + + >>> config = WavegradConfig(data_path="/home/erogol/nvme/gdrive/Datasets/LJSpeech-1.1/wavs/", output_path=output_path,) + >>> args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) + >>> trainer = Trainer(args, config, output_path, c_logger, tb_logger) + >>> trainer.fit() + + TODO: + - Accumulate gradients b/w batches. + - Deepspeed integration + - Profiler integration. + - Overfitting to a batch. + - TPU training + """ + + # set and initialize Pytorch runtime + self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark) + + if config is None: + # parse config from console arguments + config, output_path, _, c_logger, tb_logger = process_args(args) + + self.output_path = output_path + self.args = args + self.config = config + + # init loggers + self.c_logger = ConsoleLogger() if c_logger is None else c_logger + if tb_logger is None: + self.tb_logger = TensorboardLogger(output_path, model_name=config.model) + self.tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) + else: + self.tb_logger = tb_logger + log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") + self._setup_logger_config(log_file) + + self.total_steps_done = 0 + self.epochs_done = 0 + self.restore_step = 0 + self.best_loss = float("inf") + self.train_loader = None + self.eval_loader = None + self.output_audio_path = os.path.join(output_path, "test_audios") + + self.keep_avg_train = None + self.keep_avg_eval = None + + self.use_apex = self._is_apex_available() + self.use_amp_scaler = self.config.mixed_precision and self.use_cuda + + # init audio processor + self.ap = AudioProcessor(**self.config.audio.to_dict()) + + # load dataset samples + # TODO: refactor this + if "datasets" in self.config: + # load data for `tts` models + self.data_train, self.data_eval = load_meta_data(self.config.datasets) + elif self.config.feature_path is not None: + # load data for `vocoder`models + print(f" > Loading features from: {self.config.feature_path}") + self.data_eval, self.data_train = load_wav_feat_data( + self.config.data_path, self.config.feature_path, self.config.eval_split_size + ) + else: + # load data for `vocoder`models + self.data_eval, self.data_train = load_wav_data(self.config.data_path, self.config.eval_split_size) + + # init TTS model + if model is not None: + self.model = model + else: + self.model = self.get_model(self.config) + + # setup criterion + self.criterion = self.get_criterion(self.model) + + # DISTRUBUTED + if self.num_gpus > 1: + init_distributed( + args.rank, + self.num_gpus, + args.group_id, + self.config.distributed_backend, + self.config.distributed_url, + ) + + if self.use_cuda: + self.model.cuda() + if isinstance(self.criterion, list): + self.criterion = [x.cuda() for x in self.criterion] + else: + self.criterion.cuda() + + # setup optimizer + self.optimizer = self.get_optimizer(self.model, self.config) + + # callback + self.callbacks = TrainerCallback(self) + self.callbacks.on_init_start() + + # init AMP + if 
self.use_amp_scaler: + if self.use_apex: + self.scaler = None + self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level="O1") + if isinstance(self.optimizer, list): + self.scaler = [torch.cuda.amp.GradScaler()] * len(self.optimizer) + else: + self.scaler = torch.cuda.amp.GradScaler() + else: + self.scaler = None + + if self.args.restore_path: + self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( + self.config, args.restore_path, self.model, self.optimizer, self.scaler + ) + + # setup scheduler + self.scheduler = self.get_scheduler(self.model, self.config, self.optimizer) + + # DISTRUBUTED + if self.num_gpus > 1: + self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) + + # count model size + num_params = count_parameters(self.model) + print("\n > Model has {} parameters".format(num_params)) + + self.callbacks.on_init_end() + + @staticmethod + def get_model(config: Coqpit) -> nn.Module: + """Initialize model from config. + + Args: + config (Coqpit): Model config. + + Returns: + nn.Module: initialized model. + """ + # TODO: better model setup + try: + model = setup_tts_model(config) + except ModuleNotFoundError: + model = setup_vocoder_model(config) + return model + + def restore_model( + self, + config: Coqpit, + restore_path: str, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scaler: torch.cuda.amp.GradScaler = None, + ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: + """Restore training from an old run. It restores model, optimizer, AMP scaler and training stats. + + Args: + config (Coqpit): Model config. + restore_path (str): Path to the restored training run. + model (nn.Module): Model to restored. + optimizer (torch.optim.Optimizer): Optimizer to restore. + scaler (torch.cuda.amp.GradScaler, optional): AMP scaler to restore. Defaults to None. 
+ + Returns: + Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: [description] + """ + + def _restore_list_objs(states, obj): + if isinstance(obj, list): + for idx, state in enumerate(states): + obj[idx].load_state_dict(state) + else: + obj.load_state_dict(states) + return obj + + print(" > Restoring from %s ..." % os.path.basename(restore_path)) + checkpoint = torch.load(restore_path) + try: + print(" > Restoring Model...") + model.load_state_dict(checkpoint["model"]) + print(" > Restoring Optimizer...") + optimizer = _restore_list_objs(checkpoint["optimizer"], optimizer) + if "scaler" in checkpoint and self.use_amp_scaler: + print(" > Restoring AMP Scaler...") + scaler = _restore_list_objs(checkpoint["scaler"], scaler) + except (KeyError, RuntimeError): + print(" > Partial model initialization...") + model_dict = model.state_dict() + model_dict = set_init_dict(model_dict, checkpoint["model"], config) + model.load_state_dict(model_dict) + del model_dict + + if isinstance(self.optimizer, list): + for idx, optim in enumerate(optimizer): + for group in optim.param_groups: + group["lr"] = self.get_lr(model, config)[idx] + else: + for group in optimizer.param_groups: + group["lr"] = self.get_lr(model, config) + print( + " > Model restored from step %d" % checkpoint["step"], + ) + restore_step = checkpoint["step"] + return model, optimizer, scaler, restore_step + + @staticmethod + def _get_loader( + model: nn.Module, + config: Coqpit, + ap: AudioProcessor, + is_eval: bool, + data_items: List, + verbose: bool, + num_gpus: int, + ) -> DataLoader: + if hasattr(model, "get_data_loader"): + loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus) + return loader + + def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: + """Initialize and return a training data loader. + + Args: + ap (AudioProcessor): Audio processor. + data_items (List): Data samples used for training. 
+ verbose (bool): enable/disable printing loader stats at initialization. + + Returns: + DataLoader: Initialized training data loader. + """ + return self._get_loader(self.model, self.config, ap, False, data_items, verbose, self.num_gpus) + + def get_eval_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader: + return self._get_loader(self.model, self.config, ap, True, data_items, verbose, self.num_gpus) + + def format_batch(self, batch: List) -> Dict: + """Format dataloader ouput and return a batch. + + Args: + batch (List): Batch returned by the dataloader. + + Returns: + Dict: Formatted batch. + """ + batch = self.model.format_batch(batch) + if self.use_cuda: + for k, v in batch.items(): + batch[k] = to_cuda(v) + return batch + + @staticmethod + def _model_train_step( + batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None + ) -> Tuple[Dict, Dict]: + """ + Perform a trainig forward step. Compute model outputs and losses. + + Args: + batch (Dict): [description] + model (nn.Module): [description] + criterion (nn.Module): [description] + optimizer_idx (int, optional): [description]. Defaults to None. + + Returns: + Tuple[Dict, Dict]: [description] + """ + input_args = [batch, criterion] + if optimizer_idx is not None: + input_args.append(optimizer_idx) + # unwrap model in DDP training + if hasattr(model, "module"): + return model.module.train_step(*input_args) + return model.train_step(*input_args) + + def _optimize( + self, + batch: Dict, + model: nn.Module, + optimizer: Union[torch.optim.Optimizer, List], + scaler: "AMPScaler", + criterion: nn.Module, + scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List], # pylint: disable=protected-access + config: Coqpit, + optimizer_idx: int = None, + ) -> Tuple[Dict, Dict, int, torch.Tensor]: + """Perform a forward - backward pass and run the optimizer. + + Args: + batch (Dict): Input batch. If + model (nn.Module): Model for training. Defaults to None. 
+ optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use. + scaler (AMPScaler): AMP scaler. + criterion (nn.Module): Model's criterion. + scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer. + config (Coqpit): Model config. + optimizer_idx (int, optional): Target optimizer being used. Defaults to None. + + Raises: + RuntimeError: When the loss is NaN. + + Returns: + Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm. + """ + step_start_time = time.time() + # zero-out optimizer + optimizer.zero_grad() + with torch.cuda.amp.autocast(enabled=config.mixed_precision): + if optimizer_idx is not None: + outputs, loss_dict = self._model_train_step(batch, model, criterion, optimizer_idx=optimizer_idx) + else: + outputs, loss_dict = self._model_train_step(batch, model, criterion) + + # skip the rest + if outputs is None: + step_time = time.time() - step_start_time + return None, {}, step_time, 0 + + # check nan loss + if torch.isnan(loss_dict["loss"]).any(): + raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.") + + # set gradient clipping threshold + if "grad_clip" in config and config.grad_clip is not None: + if optimizer_idx is not None: + grad_clip = config.grad_clip[optimizer_idx] + else: + grad_clip = config.grad_clip + else: + grad_clip = 0.0 # meaning no gradient clipping + + # TODO: compute grad norm + if grad_clip <= 0: + grad_norm = 0 + + # optimizer step + update_lr_scheduler = True + if self.use_amp_scaler: + if self.use_apex: + with amp.scale_loss(loss_dict["loss"], self.optimizer) as scaled_loss: + scaled_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + self.config.grad_clip, + ) + else: + # model optimizer step in mixed precision mode + scaler.scale(loss_dict["loss"]).backward() + scaler.unscale_(optimizer) + 
if grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + scale_prev = scaler.get_scale() + scaler.step(optimizer) + scaler.update() + update_lr_scheduler = scale_prev <= scaler.get_scale() + else: + # main model optimizer step + loss_dict["loss"].backward() + if grad_clip > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + optimizer.step() + + step_time = time.time() - step_start_time + + # setup lr + if scheduler is not None and update_lr_scheduler: + scheduler.step() + + # detach losses + loss_dict = self._detach_loss_dict(loss_dict) + if optimizer_idx is not None: + loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss") + loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm + return outputs, loss_dict, step_time, grad_norm + + @staticmethod + def _detach_loss_dict(loss_dict: Dict) -> Dict: + """Detach loss values from autograp. + + Args: + loss_dict (Dict): losses. + + Returns: + Dict: losses detached from autograph. + """ + loss_dict_detached = {} + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_detached[key] = value + else: + loss_dict_detached[key] = value.item() + return loss_dict_detached + + def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: + """Perform a training step on a batch of inputs and log the process. + + Args: + batch (Dict): Input batch. + batch_n_steps (int): Number of steps needed to complete an epoch. Needed for logging. + step (int): Current step number in this epoch. + loader_start_time (float): The time when the data loading is started. Needed for logging. + + Returns: + Tuple[Dict, Dict]: Model outputs and losses. + """ + self.callbacks.on_train_step_start() + # format data + batch = self.format_batch(batch) + loader_time = time.time() - loader_start_time + + # conteainers to hold model outputs and losses for each optimizer. 
+ outputs_per_optimizer = None + log_dict = {} + loss_dict = {} + if not isinstance(self.optimizer, list): + # training with a single optimizer + outputs, loss_dict_new, step_time, grad_norm = self._optimize( + batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config + ) + loss_dict.update(loss_dict_new) + else: + # training with multiple optimizers (e.g. GAN) + outputs_per_optimizer = [None] * len(self.optimizer) + total_step_time = 0 + for idx, optimizer in enumerate(self.optimizer): + criterion = self.criterion + scaler = self.scaler[idx] if self.use_amp_scaler else None + scheduler = self.scheduler[idx] + outputs, loss_dict_new, step_time, grad_norm = self._optimize( + batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx + ) + # skip the rest if the model returns None + total_step_time += step_time + outputs_per_optimizer[idx] = outputs + # if None, model skipped this optimizer + if loss_dict_new is not None: + loss_dict.update(loss_dict_new) + outputs = outputs_per_optimizer + + # update avg stats + keep_avg_update = dict() + for key, value in log_dict.items(): + keep_avg_update["avg_" + key] = value + keep_avg_update["avg_loader_time"] = loader_time + keep_avg_update["avg_step_time"] = step_time + self.keep_avg_train.update_values(keep_avg_update) + + # print training progress + if self.total_steps_done % self.config.print_step == 0: + # log learning rates + lrs = {} + if isinstance(self.optimizer, list): + for idx, optimizer in enumerate(self.optimizer): + current_lr = self.optimizer[idx].param_groups[0]["lr"] + lrs.update({f"current_lr_{idx}": current_lr}) + else: + current_lr = self.optimizer.param_groups[0]["lr"] + lrs = {"current_lr": current_lr} + log_dict.update(lrs) + if grad_norm > 0: + log_dict.update({"grad_norm": grad_norm}) + # log run-time stats + log_dict.update( + { + "step_time": round(step_time, 4), + "loader_time": round(loader_time, 4), + } + ) + self.c_logger.print_train_step( + 
batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values + ) + + if self.args.rank == 0: + # Plot Training Iter Stats + # reduce TB load and don't log every step + if self.total_steps_done % self.config.tb_plot_step == 0: + iter_stats = log_dict + iter_stats.update(loss_dict) + self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats) + if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0: + if self.config.checkpoint: + # checkpoint the model + model_loss = ( + loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"] + ) + save_checkpoint( + self.config, + self.model, + self.optimizer, + self.scaler if self.use_amp_scaler else None, + self.total_steps_done, + self.epochs_done, + self.output_path, + model_loss=model_loss, + ) + # training visualizations + figures, audios = None, None + if hasattr(self.model, "module") and hasattr(self.model.module, "train_log"): + figures, audios = self.model.module.train_log(self.ap, batch, outputs) + elif hasattr(self.model, "train_log"): + figures, audios = self.model.train_log(self.ap, batch, outputs) + if figures is not None: + self.tb_logger.tb_train_figures(self.total_steps_done, figures) + if audios is not None: + self.tb_logger.tb_train_audios(self.total_steps_done, audios, self.ap.sample_rate) + self.total_steps_done += 1 + self.callbacks.on_train_step_end() + return outputs, loss_dict + + def train_epoch(self) -> None: + """Main entry point for training. 
Run training on the whole training samples.""" + self.train_loader = self.get_train_dataloader( + self.ap, + self.data_train, + verbose=True, + ) + self.model.train() + epoch_start_time = time.time() + if self.use_cuda: + batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) + else: + batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) + self.c_logger.print_train_start() + for cur_step, batch in enumerate(self.train_loader): + loader_start_time = time.time() + _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time) + epoch_time = time.time() - epoch_start_time + # Plot self.epochs_done Stats + if self.args.rank == 0: + epoch_stats = {"epoch_time": epoch_time} + epoch_stats.update(self.keep_avg_train.avg_values) + self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats) + if self.config.tb_model_param_stats: + self.tb_logger.tb_model_weights(self.model, self.total_steps_done) + + @staticmethod + def _model_eval_step( + batch: Dict, model: nn.Module, criterion: nn.Module, optimizer_idx: int = None + ) -> Tuple[Dict, Dict]: + """ + Perform an evaluation forward pass. Compute model outputs and losses with no gradients. + + Args: + batch (Dict): Batch of inputs. + model (nn.Module): Model to call evaluation. + criterion (nn.Module): Model criterion. + optimizer_idx (int, optional): Optimizer ID to define the closure in multi-optimizer training. Defaults to None. + + Returns: + Tuple[Dict, Dict]: model outputs and losses.
+ """ + input_args = [batch, criterion] + if optimizer_idx is not None: + input_args.append(optimizer_idx) + if hasattr(model, "module"): + return model.module.eval_step(*input_args) + return model.eval_step(*input_args) + + def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: + with torch.no_grad(): + outputs_per_optimizer = None + loss_dict = {} + if not isinstance(self.optimizer, list): + outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion) + else: + outputs_per_optimizer = [None] * len(self.optimizer) + for idx, _ in enumerate(self.optimizer): + criterion = self.criterion + outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx) + outputs_per_optimizer[idx] = outputs + if loss_dict_new is not None: + loss_dict.update(loss_dict_new) + outputs = outputs_per_optimizer + + # update avg stats + update_eval_values = dict() + for key, value in loss_dict.items(): + update_eval_values["avg_" + key] = value + self.keep_avg_eval.update_values(update_eval_values) + + if self.config.print_eval: + self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values) + return outputs, loss_dict + + def eval_epoch(self) -> None: + self.eval_loader = ( + self.get_eval_dataloader( + self.ap, + self.data_eval, + verbose=True, + ) + if self.config.run_eval + else None + ) + + self.model.eval() + self.c_logger.print_eval_start() + loader_start_time = time.time() + batch = None + for cur_step, batch in enumerate(self.eval_loader): + # format data + batch = self.format_batch(batch) + loader_time = time.time() - loader_start_time + self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) + outputs, _ = self.eval_step(batch, cur_step) + # plot epoch stats, artifacts and figures + if self.args.rank == 0: + figures, audios = None, None + if hasattr(self.model, "module") and hasattr(self.model.module, "eval_log"): + figures, audios = self.model.module.eval_log(self.ap, batch, outputs) + elif 
hasattr(self.model, "eval_log"): + figures, audios = self.model.eval_log(self.ap, batch, outputs) + if figures is not None: + self.tb_logger.tb_eval_figures(self.total_steps_done, figures) + if audios is not None: + self.tb_logger.tb_eval_audios(self.total_steps_done, audios, self.ap.sample_rate) + + def test_run(self) -> None: + """Run test and log the results. Test run must be defined by the model. + Model must return figures and audios to be logged by the Tensorboard logger.""" + if hasattr(self.model, "test_run"): + if hasattr(self.eval_loader.load_test_samples): + samples = self.eval_loader.load_test_samples(1) + figures, audios = self.model.test_run(samples) + else: + figures, audios = self.model.test_run() + self.tb_logger.tb_test_audios(self.total_steps_done, audios, self.config.audio["sample_rate"]) + self.tb_logger.tb_test_figures(self.total_steps_done, figures) + + def _fit(self) -> None: + """🏃 train -> evaluate -> test for the number of epochs.""" + if self.restore_step != 0 or self.args.best_path: + print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") + self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] + print(f" > Starting with loaded last best loss {self.best_loss}.") + + self.total_steps_done = self.restore_step + + for epoch in range(0, self.config.epochs): + self.callbacks.on_epoch_start() + self.keep_avg_train = KeepAverage() + self.keep_avg_eval = KeepAverage() if self.config.run_eval else None + self.epochs_done = epoch + self.c_logger.print_epoch_start(epoch, self.config.epochs) + self.train_epoch() + if self.config.run_eval: + self.eval_epoch() + if epoch >= self.config.test_delay_epochs and self.args.rank < 0: + self.test_run() + self.c_logger.print_epoch_end( + epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values + ) + self.save_best_model() + self.callbacks.on_epoch_end() + + def fit(self) -> None: + """Where the ✨️magic✨️ 
happens...""" + try: + self._fit() + except KeyboardInterrupt: + self.callbacks.on_keyboard_interrupt() + # if the output folder is empty remove the run. + remove_experiment_folder(self.output_path) + # stop without error signal + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except BaseException: # pylint: disable=broad-except + remove_experiment_folder(self.output_path) + traceback.print_exc() + sys.exit(1) + + def save_best_model(self) -> None: + """Save the best model. It only saves if the current target loss is smaller then the previous.""" + self.best_loss = save_best_model( + self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"], + self.best_loss, + self.config, + self.model, + self.optimizer, + self.scaler if self.use_amp_scaler else None, + self.total_steps_done, + self.epochs_done, + self.output_path, + keep_all_best=self.config.keep_all_best, + keep_after=self.config.keep_after, + ) + + @staticmethod + def _setup_logger_config(log_file: str) -> None: + logging.basicConfig( + level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()] + ) @staticmethod def _is_apex_available(): return importlib.util.find_spec("apex") is not None @staticmethod - @abstractmethod - def get_model(*args, **kwargs) -> nn.Module: - pass + def get_optimizer(model: nn.Module, config: Coqpit) -> Union[torch.optim.Optimizer, List]: + if hasattr(model, "get_optimizer"): + optimizer = model.get_optimizer() + if optimizer is None: + optimizer_name = config.optimizer + optimizer_params = config.optimizer_params + return get_optimizer(optimizer_name, optimizer_params, config.lr, model) + return optimizer @staticmethod - @abstractmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer: - pass + def get_lr(model: nn.Module, config: Coqpit) -> Union[float, List[float]]: + lr = None + if hasattr(model, "get_lr"): + lr = model.get_lr() + if lr is 
None: + lr = config.lr + return lr @staticmethod - @abstractmethod def get_scheduler( - config: Coqpit, optimizer: torch.optim.Optimizer - ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access - pass + model: nn.Module, config: Coqpit, optimizer: Union[torch.optim.Optimizer, List] + ) -> Union[torch.optim.lr_scheduler._LRScheduler, List]: # pylint: disable=protected-access + scheduler = None + if hasattr(model, "get_scheduler"): + scheduler = model.get_scheduler(optimizer) + if scheduler is None: + lr_scheduler = config.lr_scheduler + lr_scheduler_params = config.lr_scheduler_params + return get_scheduler(lr_scheduler, lr_scheduler_params, optimizer) + return scheduler @staticmethod - @abstractmethod - def get_criterion(config: Coqpit) -> nn.Module: - pass + def get_criterion(model: nn.Module) -> nn.Module: + criterion = None + criterion = model.get_criterion() + return criterion - @abstractmethod - def restore_model(self, *args, **kwargs) -> Tuple: - pass - @abstractmethod - def get_train_dataloader(self, *args, **kwargs) -> _DataLoader: - pass +def init_arguments(): + train_config = TrainingArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser - @abstractmethod - def get_eval_dataloder(self, *args, **kwargs) -> _DataLoader: - pass - @abstractmethod - def format_batch(self, batch: List) -> Dict: - pass +def get_last_checkpoint(path): + """Get latest checkpoint or/and best model in path. - @abstractmethod - def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - pass + It is based on globbing for `*.pth.tar` and the RegEx + `(checkpoint|best_model)_([0-9]+)`. - @abstractmethod - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: - pass + Args: + path (list): Path to files to be compared. - @abstractmethod - def train_epoch(self) -> None: - pass + Raises: + ValueError: If no checkpoint or best_model files are found. 
- @abstractmethod - def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]: - pass + Returns: + last_checkpoint (str): Last checkpoint filename. + """ + file_names = glob.glob(os.path.join(path, "*.pth.tar")) + last_models = {} + last_model_nums = {} + for key in ["checkpoint", "best_model"]: + last_model_num = None + last_model = None + # pass all the checkpoint files and find + # the one with the largest model number suffix. + for file_name in file_names: + match = re.search(f"{key}_([0-9]+)", file_name) + if match is not None: + model_num = int(match.groups()[0]) + if last_model_num is None or model_num > last_model_num: + last_model_num = model_num + last_model = file_name - @abstractmethod - def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: - pass + # if there is not checkpoint found above + # find the checkpoint with the latest + # modification date. + key_file_names = [fn for fn in file_names if key in fn] + if last_model is None and len(key_file_names) > 0: + last_model = max(key_file_names, key=os.path.getctime) + last_model_num = torch.load(last_model)["step"] - @abstractmethod - def eval_epoch(self) -> None: - pass + if last_model is not None: + last_models[key] = last_model + last_model_nums[key] = last_model_num - @abstractmethod - def test_run(self) -> None: - pass + # check what models were found + if not last_models: + raise ValueError(f"No models found in continue path {path}!") + if "checkpoint" not in last_models: # no checkpoint just best model + last_models["checkpoint"] = last_models["best_model"] + elif "best_model" not in last_models: # no best model + # this shouldn't happen, but let's handle it just in case + last_models["best_model"] = None + # finally check if last best model is more recent than checkpoint + elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: + last_models["checkpoint"] = last_models["best_model"] - @abstractmethod - def fit(self) -> None: - pass + return last_models["checkpoint"], 
last_models["best_model"] - @abstractmethod - def save_best_model(self) -> None: - pass - - @abstractmethod - def on_epoch_start(self) -> None: - pass +def process_args(args, config=None): + """Process parsed command line arguments. - @abstractmethod - def on_epoch_end(self) -> None: - pass + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. - @abstractmethod - def on_train_step_start(self) -> None: - pass + Returns: + c (TTS.utils.io.AttrDict): Config parameters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. + tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does + the TensorBoard logging. + """ + if isinstance(args, tuple): + args, coqpit_overrides = args + if args.continue_path: + # continue a previous training from its output folder + experiment_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + args.restore_path, best_model = get_last_checkpoint(args.continue_path) + if not args.best_path: + args.best_path = best_model + # setup output paths and read configs + if config is None: + config = load_config(args.config_path) + # override values from command-line args + config.parse_known_args(coqpit_overrides, relaxed_parser=True) + if config.mixed_precision: + print(" > Mixed precision mode is ON") + experiment_path = args.continue_path + if not experiment_path: + experiment_path = create_experiment_folder(config.output_path, config.run_name) + audio_path = os.path.join(experiment_path, "test_audios") + # setup rank 0 process in distributed training + tb_logger = None + if args.rank == 0: + os.makedirs(audio_path, exist_ok=True) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path +
new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if config.has("characters_config"): + used_characters = parse_symbols() + new_fields["characters"] = used_characters + copy_model_files(config, experiment_path, new_fields) + os.chmod(audio_path, 0o775) + os.chmod(experiment_path, 0o775) + tb_logger = TensorboardLogger(experiment_path, model_name=config.model) + # write model desc to tensorboard + tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) + c_logger = ConsoleLogger() + return config, experiment_path, audio_path, c_logger, tb_logger - @abstractmethod - def on_train_step_end(self) -> None: - pass + +def init_training(argv: Union[List, Coqpit], config: Coqpit = None): + """Initialization of a training run.""" + if isinstance(argv, Coqpit): + parser = argv.init_argparse(arg_prefix="") + else: + parser = init_arguments() + args = parser.parse_known_args() + config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py deleted file mode 100644 index 705ea5bc..00000000 --- a/TTS/tts/models/tacotron_abstract.py +++ /dev/null @@ -1,245 +0,0 @@ -import copy -from abc import ABC, abstractmethod -from typing import Dict - -import torch -from torch import nn - -from TTS.tts.utils.data import sequence_mask -from TTS.utils.generic_utils import format_aux_input -from TTS.utils.training import gradual_training_scheduler - - -class TacotronAbstract(ABC, nn.Module): - def __init__( - self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - prenet_dropout_at_inference=False, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - d_vector_dim=None, - use_gst=False, - gst=None, - gradual_training=None, - ): - """Abstract Tacotron class""" - super().__init__() - self.num_chars = num_chars - self.r = r - self.decoder_output_dim = decoder_output_dim - self.postnet_output_dim = postnet_output_dim - self.use_gst = use_gst - self.gst = gst - self.num_speakers = num_speakers - 
self.bidirectional_decoder = bidirectional_decoder - self.double_decoder_consistency = double_decoder_consistency - self.ddc_r = ddc_r - self.attn_type = attn_type - self.attn_win = attn_win - self.attn_norm = attn_norm - self.prenet_type = prenet_type - self.prenet_dropout = prenet_dropout - self.prenet_dropout_at_inference = prenet_dropout_at_inference - self.forward_attn = forward_attn - self.trans_agent = trans_agent - self.forward_attn_mask = forward_attn_mask - self.location_attn = location_attn - self.attn_K = attn_K - self.separate_stopnet = separate_stopnet - self.encoder_in_features = encoder_in_features - self.decoder_in_features = decoder_in_features - self.d_vector_dim = d_vector_dim - self.gradual_training = gradual_training - - # layers - self.embedding = None - self.encoder = None - self.decoder = None - self.postnet = None - - # multispeaker - if self.d_vector_dim is None: - # if d_vector_dim is None we need use the nn.Embedding, with default d_vector_dim - self.use_d_vectors = False - else: - # if d_vector_dim is not None we need use speaker embedding per sample - self.use_d_vectors = True - - # global style token - if self.gst and use_gst: - self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim - self.gst_layer = None - - # model states - self.embedded_speakers = None - self.embedded_speakers_projected = None - - # additional layers - self.decoder_backward = None - self.coarse_decoder = None - - @staticmethod - def _format_aux_input(aux_input: Dict) -> Dict: - return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) - - ############################# - # INIT FUNCTIONS - ############################# - - def _init_states(self): - self.embedded_speakers = None - self.embedded_speakers_projected = None - - def _init_backward_decoder(self): - self.decoder_backward = copy.deepcopy(self.decoder) - - def _init_coarse_decoder(self): - self.coarse_decoder = copy.deepcopy(self.decoder) - 
self.coarse_decoder.r_init = self.ddc_r - self.coarse_decoder.set_r(self.ddc_r) - - ############################# - # CORE FUNCTIONS - ############################# - - @abstractmethod - def forward(self): - pass - - @abstractmethod - def inference(self): - pass - - def load_checkpoint( - self, config, checkpoint_path, eval=False - ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) - self.load_state_dict(state["model"]) - self.decoder.set_r(state["r"]) - if eval: - self.eval() - assert not self.training - - ############################# - # COMMON COMPUTE FUNCTIONS - ############################# - - def compute_masks(self, text_lengths, mel_lengths): - """Compute masks against sequence paddings.""" - # B x T_in_max (boolean) - input_mask = sequence_mask(text_lengths) - output_mask = None - if mel_lengths is not None: - max_len = mel_lengths.max() - r = self.decoder.r - max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len - output_mask = sequence_mask(mel_lengths, max_len=max_len) - return input_mask, output_mask - - def _backward_pass(self, mel_specs, encoder_outputs, mask): - """Run backwards decoder""" - decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask - ) - decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() - return decoder_outputs_b, alignments_b - - def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): - """Double Decoder Consistency""" - T = mel_specs.shape[1] - if T % self.coarse_decoder.r > 0: - padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) - mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) - decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( - encoder_outputs.detach(), mel_specs, input_mask - ) - # scale_factor = self.decoder.r_init / self.decoder.r - alignments_backward = 
torch.nn.functional.interpolate( - alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest" - ).transpose(1, 2) - decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) - decoder_outputs_backward = decoder_outputs_backward[:, :T, :] - return decoder_outputs_backward, alignments_backward - - ############################# - # EMBEDDING FUNCTIONS - ############################# - - def compute_speaker_embedding(self, speaker_ids): - """Compute speaker embedding vectors""" - if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1) - if hasattr(self, "speaker_project_mel") and speaker_ids is not None: - self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1) - - def compute_gst(self, inputs, style_input, speaker_embedding=None): - """Compute global style token""" - if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) - if speaker_embedding is not None: - query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) - - _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) - for k_token, v_amplifier in style_input.items(): - key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) - gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) - gst_outputs = gst_outputs + gst_outputs_att * v_amplifier - elif style_input is None: - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) - else: - gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable - inputs = self._concat_speaker_embedding(inputs, gst_outputs) - return inputs - - 
@staticmethod - def _add_speaker_embedding(outputs, embedded_speakers): - embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) - outputs = outputs + embedded_speakers_ - return outputs - - @staticmethod - def _concat_speaker_embedding(outputs, embedded_speakers): - embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) - outputs = torch.cat([outputs, embedded_speakers_], dim=-1) - return outputs - - ############################# - # CALLBACKS - ############################# - - def on_epoch_start(self, trainer): - """Callback for setting values wrt gradual training schedule. - - Args: - trainer (TrainerTTS): TTS trainer object that is used to train this model. - """ - if self.gradual_training: - r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) - trainer.config.r = r - self.decoder.set_r(r) - if trainer.config.bidirectional_decoder: - trainer.model.decoder_backward.set_r(r) - trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True) - trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r) - print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/TTS/tts/trainer_tts.py b/TTS/tts/trainer_tts.py deleted file mode 100644 index 6c900120..00000000 --- a/TTS/tts/trainer_tts.py +++ /dev/null @@ -1,709 +0,0 @@ -# -*- coding: utf-8 -*- - -import importlib -import logging -import os -import time -from argparse import Namespace -from typing import Dict, List, Tuple, Union - -import torch -from coqpit import Coqpit - -# DISTRIBUTED -from torch import nn -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from TTS.trainer import TrainerAbstract -from TTS.tts.datasets import TTSDataset, load_meta_data -from TTS.tts.layers import setup_loss -from TTS.tts.models import 
setup_model -from TTS.tts.utils.io import save_best_model, save_checkpoint -from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import make_symbols -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict, to_cuda -from TTS.utils.logging import ConsoleLogger, TensorboardLogger -from TTS.utils.training import check_update, setup_torch_training_env - - -# pylint: disable=import-outside-toplevel, too-many-public-methods - -class TrainerTTS(TrainerAbstract): - use_cuda, num_gpus = setup_torch_training_env(True, False) - - def __init__( - self, - args: Union[Coqpit, Namespace], - config: Coqpit, - c_logger: ConsoleLogger = None, - tb_logger: TensorboardLogger = None, - model: nn.Module = None, - output_path: str = None, - ) -> None: - self.args = args - self.config = config - self.c_logger = ConsoleLogger() if c_logger is None else c_logger - if tb_logger is None: - self.tb_logger = TensorboardLogger(output_path, model_name=config.model) - self.tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) - else: - self.tb_logger = tb_logger - self.output_path = output_path - - self.total_steps_done = 0 - self.epochs_done = 0 - self.restore_step = 0 - self.best_loss = float("inf") - self.train_loader = None - self.eval_loader = None - self.output_audio_path = os.path.join(output_path, "test_audios") - - self.keep_avg_train = None - self.keep_avg_eval = None - - log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") - self._setup_logger_config(log_file) - - # model, audio processor, datasets, loss - # init audio processor - self.ap = AudioProcessor(**self.config.audio.to_dict()) - - # init character processor - self.model_characters = self.get_character_processor(self.config) - - # load dataset samples - self.data_train, self.data_eval = load_meta_data(self.config.datasets) - - # default speaker manager - self.speaker_manager = self.get_speaker_manager(self.config, args.restore_path, output_path, self.data_train) - - # init TTS model - if model is not None: - self.model = model - else: - self.model = self.get_model( - len(self.model_characters), - self.speaker_manager.num_speakers, - self.config, - self.speaker_manager.d_vector_dim if self.speaker_manager.d_vectors else None, - ) - - # setup criterion - self.criterion = self.get_criterion(self.config) - - # DISTRUBUTED - if self.num_gpus > 1: - init_distributed( - args.rank, - self.num_gpus, - args.group_id, - self.config.distributed_backend, - self.config.distributed_url, - ) - - if self.use_cuda: - self.model.cuda() - self.criterion.cuda() - - # scalers for mixed precision training - self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None - - # setup optimizer - self.optimizer = self.get_optimizer(self.model, self.config) - - if self.args.restore_path: - self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model( - self.config, args.restore_path, self.model, self.optimizer, self.scaler - ) - - # setup scheduler - 
self.scheduler = self.get_scheduler(self.config, self.optimizer) - - # DISTRUBUTED - if self.num_gpus > 1: - self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank) - - # count model size - num_params = count_parameters(self.model) - print("\n > Model has {} parameters".format(num_params)) - - @staticmethod - def get_model(num_chars: int, num_speakers: int, config: Coqpit, d_vector_dim: int) -> nn.Module: - model = setup_model(num_chars, num_speakers, config, d_vector_dim) - return model - - @staticmethod - def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer: - optimizer_name = config.optimizer - optimizer_params = config.optimizer_params - if optimizer_name.lower() == "radam": - module = importlib.import_module("TTS.utils.radam") - optimizer = getattr(module, "RAdam") - else: - optimizer = getattr(torch.optim, optimizer_name) - return optimizer(model.parameters(), lr=config.lr, **optimizer_params) - - @staticmethod - def get_character_processor(config: Coqpit) -> str: - # setup custom characters if set in config file. 
- # TODO: implement CharacterProcessor - if config.characters is not None: - symbols, phonemes = make_symbols(**config.characters.to_dict()) - else: - from TTS.tts.utils.text.symbols import phonemes, symbols - model_characters = phonemes if config.use_phonemes else symbols - return model_characters - - @staticmethod - def get_speaker_manager( - config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = None - ) -> SpeakerManager: - speaker_manager = get_speaker_manager(config, restore_path, data_train, out_path) - return speaker_manager - - @staticmethod - def get_scheduler( - config: Coqpit, optimizer: torch.optim.Optimizer - ) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access - lr_scheduler = config.lr_scheduler - lr_scheduler_params = config.lr_scheduler_params - if lr_scheduler is None: - return None - if lr_scheduler.lower() == "noamlr": - from TTS.utils.training import NoamLR - - scheduler = NoamLR - else: - scheduler = getattr(torch.optim, lr_scheduler) - return scheduler(optimizer, **lr_scheduler_params) - - @staticmethod - def get_criterion(config: Coqpit) -> nn.Module: - return setup_loss(config) - - def restore_model( - self, - config: Coqpit, - restore_path: str, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scaler: torch.cuda.amp.GradScaler = None, - ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]: - print(" > Restoring from %s ..." 
% os.path.basename(restore_path)) - checkpoint = torch.load(restore_path) - try: - print(" > Restoring Model...") - model.load_state_dict(checkpoint["model"]) - print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint["optimizer"]) - if "scaler" in checkpoint and config.mixed_precision: - print(" > Restoring AMP Scaler...") - scaler.load_state_dict(checkpoint["scaler"]) - except (KeyError, RuntimeError): - print(" > Partial model initialization...") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - for group in optimizer.param_groups: - group["lr"] = self.config.lr - print( - " > Model restored from step %d" % checkpoint["step"], - ) - restore_step = checkpoint["step"] - return model, optimizer, scaler, restore_step - - def _get_loader( - self, - r: int, - ap: AudioProcessor, - is_eval: bool, - data_items: List, - verbose: bool, - speaker_ids: Union[Dict, List], - d_vectors: Union[Dict, List], - ) -> DataLoader: - if is_eval and not self.config.run_eval: - loader = None - else: - dataset = TTSDataset( - outputs_per_step=r, - text_cleaner=self.config.text_cleaner, - compute_linear_spec=self.config.model.lower() == "tacotron", - meta_data=data_items, - ap=ap, - tp=self.config.characters, - add_blank=self.config["add_blank"], - batch_group_size=0 if is_eval else self.config.batch_group_size * self.config.batch_size, - min_seq_len=self.config.min_seq_len, - max_seq_len=self.config.max_seq_len, - phoneme_cache_path=self.config.phoneme_cache_path, - use_phonemes=self.config.use_phonemes, - phoneme_language=self.config.phoneme_language, - enable_eos_bos=self.config.enable_eos_bos_chars, - use_noise_augment=not is_eval, - verbose=verbose, - speaker_id_mapping=speaker_ids if self.config.use_speaker_embedding else None, - d_vector_mapping=d_vectors - if self.config.use_speaker_embedding and self.config.use_external_speaker_embedding_file - else 
None, - ) - - if self.config.use_phonemes and self.config.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. - dataset.compute_input_seq(self.config.num_loader_workers) - dataset.sort_items() - - sampler = DistributedSampler(dataset) if self.num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=self.config.eval_batch_size if is_eval else self.config.batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - sampler=sampler, - num_workers=self.config.num_val_loader_workers if is_eval else self.config.num_loader_workers, - pin_memory=False, - ) - return loader - - def get_train_dataloader( - self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict - ) -> DataLoader: - return self._get_loader(r, ap, False, data_items, verbose, speaker_ids, d_vectors) - - def get_eval_dataloder( - self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_ids: Dict, d_vectors: Dict - ) -> DataLoader: - return self._get_loader(r, ap, True, data_items, verbose, speaker_ids, d_vectors) - - def format_batch(self, batch: List) -> Dict: - # setup input batch - text_input = batch[0] - text_lengths = batch[1] - speaker_names = batch[2] - linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None - mel_input = batch[4] - mel_lengths = batch[5] - stop_targets = batch[6] - item_idx = batch[7] - d_vectors = batch[8] - speaker_ids = batch[9] - attn_mask = batch[10] - max_text_length = torch.max(text_lengths.float()) - max_spec_length = torch.max(mel_lengths.float()) - - # compute durations from attention masks - durations = None - if attn_mask is not None: - durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) - for idx, am in enumerate(attn_mask): - # compute raw durations - c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] - # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) - c_idxs, 
counts = torch.unique(c_idxs, return_counts=True) - dur = torch.ones([text_lengths[idx]]).to(counts.dtype) - dur[c_idxs] = counts - # smooth the durations and set any 0 duration to 1 - # by cutting off from the largest duration indeces. - extra_frames = dur.sum() - mel_lengths[idx] - largest_idxs = torch.argsort(-dur)[:extra_frames] - dur[largest_idxs] -= 1 - assert ( - dur.sum() == mel_lengths[idx] - ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" - durations[idx, : text_lengths[idx]] = dur - - # set stop targets view, we predict a single stop token per iteration. - stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) - - # dispatch batch to GPU - if self.use_cuda: - text_input = to_cuda(text_input) - text_lengths = to_cuda(text_lengths) - mel_input = to_cuda(mel_input) - mel_lengths = to_cuda(mel_lengths) - linear_input = to_cuda(linear_input) if self.config.model.lower() in ["tacotron"] else None - stop_targets = to_cuda(stop_targets) - attn_mask = to_cuda(attn_mask) if attn_mask is not None else None - durations = to_cuda(durations) if attn_mask is not None else None - if speaker_ids is not None: - speaker_ids = to_cuda(speaker_ids) - if d_vectors is not None: - d_vectors = to_cuda(d_vectors) - - return { - "text_input": text_input, - "text_lengths": text_lengths, - "speaker_names": speaker_names, - "mel_input": mel_input, - "mel_lengths": mel_lengths, - "linear_input": linear_input, - "stop_targets": stop_targets, - "attn_mask": attn_mask, - "durations": durations, - "speaker_ids": speaker_ids, - "d_vectors": d_vectors, - "max_text_length": max_text_length, - "max_spec_length": max_spec_length, - "item_idx": item_idx, - } - - def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: - if hasattr(self.model, "module"): - return self.model.module.train_step(batch, criterion) - return 
self.model.train_step(batch, criterion) - - def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]: - self.on_train_step_start() - step_start_time = time.time() - - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - - # zero-out optimizer - self.optimizer.zero_grad() - - with torch.cuda.amp.autocast(enabled=self.config.mixed_precision): - outputs, loss_dict = self._train_step(batch, self.criterion) - - # check nan loss - if torch.isnan(loss_dict["loss"]).any(): - raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.") - - # optimizer step - if self.config.mixed_precision: - # model optimizer step in mixed precision mode - self.scaler.scale(loss_dict["loss"]).backward() - self.scaler.unscale_(self.optimizer) - grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True) - self.scaler.step(self.optimizer) - self.scaler.update() - else: - # main model optimizer step - loss_dict["loss"].backward() - grad_norm, _ = check_update(self.model, self.config.grad_clip, ignore_stopnet=True) - self.optimizer.step() - - step_time = time.time() - step_start_time - - # setup lr - if self.config.lr_scheduler: - self.scheduler.step() - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - self.keep_avg_train.update_values(update_train_values) - - # print training progress - current_lr = self.optimizer.param_groups[0]["lr"] - if self.total_steps_done % self.config.print_step == 0: - log_dict = { - "max_spec_length": 
[batch["max_spec_length"], 1], # value, precision - "max_text_length": [batch["max_text_length"], 1], - "step_time": [step_time, 4], - "loader_time": [loader_time, 2], - "current_lr": current_lr, - } - self.c_logger.print_train_step( - batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values - ) - - if self.args.rank == 0: - # Plot Training Iter Stats - # reduce TB load - if self.total_steps_done % self.config.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time, - } - iter_stats.update(loss_dict) - self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats) - - if self.total_steps_done % self.config.save_step == 0: - if self.config.checkpoint: - # save model - save_checkpoint( - self.model, - self.optimizer, - self.total_steps_done, - self.epochs_done, - self.config.r, - self.output_path, - model_loss=loss_dict["loss"], - characters=self.model_characters, - scaler=self.scaler.state_dict() if self.config.mixed_precision else None, - ) - # training visualizations - if hasattr(self.model, "module"): - figures, audios = self.model.module.train_log(self.ap, batch, outputs) - else: - figures, audios = self.model.train_log(self.ap, batch, outputs) - self.tb_logger.tb_train_figures(self.total_steps_done, figures) - self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate) - self.total_steps_done += 1 - self.on_train_step_end() - return outputs, loss_dict - - def train_epoch(self) -> None: - self.model.train() - epoch_start_time = time.time() - if self.use_cuda: - batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus)) - else: - batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size) - self.c_logger.print_train_start() - for cur_step, batch in enumerate(self.train_loader): - loader_start_time = time.time() - _, _ = self.train_step(batch, batch_num_steps, cur_step, 
loader_start_time) - epoch_time = time.time() - epoch_start_time - # Plot self.epochs_done Stats - if self.args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(self.keep_avg_train.avg_values) - self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats) - if self.config.tb_model_param_stats: - self.tb_logger.tb_model_weights(self.model, self.total_steps_done) - - def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]: - if hasattr(self.model, "module"): - return self.model.module.eval_step(batch, self.criterion) - return self.model.eval_step(batch, self.criterion) - - def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]: - with torch.no_grad(): - step_start_time = time.time() - - with torch.cuda.amp.autocast(enabled=self.config.mixed_precision): - outputs, loss_dict = self._eval_step(batch) - - step_time = time.time() - step_start_time - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_eval_values = dict() - for key, value in loss_dict.items(): - update_eval_values["avg_" + key] = value - update_eval_values["avg_step_time"] = step_time - self.keep_avg_eval.update_values(update_eval_values) - - if self.config.print_eval: - self.c_logger.print_eval_step(step, loss_dict, self.keep_avg_eval.avg_values) - return outputs, loss_dict - - def eval_epoch(self) -> None: - self.model.eval() - self.c_logger.print_eval_start() - loader_start_time = time.time() - batch = None - for cur_step, batch in enumerate(self.eval_loader): - # format data - batch = self.format_batch(batch) - loader_time = time.time() - loader_start_time - self.keep_avg_eval.update_values({"avg_loader_time": loader_time}) - outputs, _ = self.eval_step(batch, cur_step) - # Plot epoch stats and samples from the last batch. 
- if self.args.rank == 0: - if hasattr(self.model, "module"): - figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs) - else: - figures, eval_audios = self.model.eval_log(self.ap, batch, outputs) - self.tb_logger.tb_eval_figures(self.total_steps_done, figures) - self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate) - - def test_run( - self, - ) -> None: - print(" | > Synthesizing test sentences.") - test_audios = {} - test_figures = {} - test_sentences = self.config.test_sentences - aux_inputs = self._get_aux_inputs() - for idx, sen in enumerate(test_sentences): - wav, alignment, model_outputs, _ = synthesis( - self.model, - sen, - self.config, - self.use_cuda, - self.ap, - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, - use_griffin_lim=True, - do_trim_silence=False, - ).values() - - file_path = os.path.join(self.output_audio_path, str(self.total_steps_done)) - os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) - self.ap.save_wav(wav, file_path) - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False) - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) - - self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"]) - self.tb_logger.tb_test_figures(self.total_steps_done, test_figures) - - def _get_aux_inputs(self) -> Dict: - # setup speaker_id - speaker_id = 0 if self.config.use_speaker_embedding else None - # setup d_vector - d_vector = ( - self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0]) - if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding - else None - ) - # setup style_mel - if 
self.config.has("gst_style_input"): - style_wav = self.config.gst_style_input - else: - style_wav = None - if style_wav is None and "use_gst" in self.config and self.config.use_gst: - # inicialize GST with zero dict. - style_wav = {} - print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!") - for i in range(self.config.gst["gst_num_style_tokens"]): - style_wav[str(i)] = 0 - aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} - return aux_inputs - - def fit(self) -> None: - if self.restore_step != 0 or self.args.best_path: - print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...") - self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {self.best_loss}.") - - # define data loaders - self.train_loader = self.get_train_dataloader( - self.config.r, - self.ap, - self.data_train, - verbose=True, - speaker_ids=self.speaker_manager.speaker_ids, - d_vectors=self.speaker_manager.d_vectors, - ) - self.eval_loader = ( - self.get_eval_dataloder( - self.config.r, - self.ap, - self.data_train, - verbose=True, - speaker_ids=self.speaker_manager.speaker_ids, - d_vectors=self.speaker_manager.d_vectors, - ) - if self.config.run_eval - else None - ) - - self.total_steps_done = self.restore_step - - for epoch in range(0, self.config.epochs): - self.on_epoch_start() - self.keep_avg_train = KeepAverage() - self.keep_avg_eval = KeepAverage() if self.config.run_eval else None - self.epochs_done = epoch - self.c_logger.print_epoch_start(epoch, self.config.epochs) - self.train_epoch() - if self.config.run_eval: - self.eval_epoch() - if epoch >= self.config.test_delay_epochs and self.args.rank < 0: - self.test_run() - self.c_logger.print_epoch_end( - epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values - ) - self.save_best_model() - self.on_epoch_end() - - def 
save_best_model(self) -> None: - self.best_loss = save_best_model( - self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"], - self.best_loss, - self.model, - self.optimizer, - self.total_steps_done, - self.epochs_done, - self.config.r, - self.output_path, - self.model_characters, - keep_all_best=self.config.keep_all_best, - keep_after=self.config.keep_after, - scaler=self.scaler.state_dict() if self.config.mixed_precision else None, - ) - - @staticmethod - def _setup_logger_config(log_file: str) -> None: - logging.basicConfig( - level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()] - ) - - def on_epoch_start(self) -> None: # pylint: disable=no-self-use - if hasattr(self.model, "on_epoch_start"): - self.model.on_epoch_start(self) - - if hasattr(self.criterion, "on_epoch_start"): - self.criterion.on_epoch_start(self) - - if hasattr(self.optimizer, "on_epoch_start"): - self.optimizer.on_epoch_start(self) - - def on_epoch_end(self) -> None: # pylint: disable=no-self-use - if hasattr(self.model, "on_epoch_end"): - self.model.on_epoch_end(self) - - if hasattr(self.criterion, "on_epoch_end"): - self.criterion.on_epoch_end(self) - - if hasattr(self.optimizer, "on_epoch_end"): - self.optimizer.on_epoch_end(self) - - def on_train_step_start(self) -> None: # pylint: disable=no-self-use - if hasattr(self.model, "on_train_step_start"): - self.model.on_train_step_start(self) - - if hasattr(self.criterion, "on_train_step_start"): - self.criterion.on_train_step_start(self) - - if hasattr(self.optimizer, "on_train_step_start"): - self.optimizer.on_train_step_start(self) - - def on_train_step_end(self) -> None: # pylint: disable=no-self-use - if hasattr(self.model, "on_train_step_end"): - self.model.on_train_step_end(self) - - if hasattr(self.criterion, "on_train_step_end"): - self.criterion.on_train_step_end(self) - - if hasattr(self.optimizer, "on_train_step_end"): - 
self.optimizer.on_train_step_end(self) diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py deleted file mode 100644 index 9d92ae82..00000000 --- a/TTS/utils/arguments.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -"""Argument parser for training scripts.""" - -import argparse -import glob -import os -import re - -import torch - -from TTS.config import load_config -from TTS.tts.utils.text.symbols import parse_symbols -from TTS.utils.generic_utils import create_experiment_folder, get_git_branch -from TTS.utils.io import copy_model_files -from TTS.utils.logging import ConsoleLogger, TensorboardLogger - - -def init_arguments(argv): - """Parse command line arguments of training scripts. - - Args: - argv (list): This is a list of input arguments as given by sys.argv - - Returns: - argparse.Namespace: Parsed arguments. - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "--continue_path", - type=str, - help=( - "Training output folder to continue training. Used to continue " - "a training. If it is used, 'config_path' is ignored." - ), - default="", - required="--config_path" not in argv, - ) - parser.add_argument( - "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" - ) - parser.add_argument( - "--best_path", - type=str, - help=( - "Best model file to be used for extracting best loss." 
- "If not specified, the latest best model in continue path is used" - ), - default="", - ) - parser.add_argument( - "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv - ) - parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.") - parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.") - parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.") - - return parser - - -def get_last_checkpoint(path): - """Get latest checkpoint or/and best model in path. - - It is based on globbing for `*.pth.tar` and the RegEx - `(checkpoint|best_model)_([0-9]+)`. - - Args: - path (list): Path to files to be compared. - - Raises: - ValueError: If no checkpoint or best_model files are found. - - Returns: - last_checkpoint (str): Last checkpoint filename. - """ - file_names = glob.glob(os.path.join(path, "*.pth.tar")) - last_models = {} - last_model_nums = {} - for key in ["checkpoint", "best_model"]: - last_model_num = None - last_model = None - # pass all the checkpoint files and find - # the one with the largest model number suffix. - for file_name in file_names: - match = re.search(f"{key}_([0-9]+)", file_name) - if match is not None: - model_num = int(match.groups()[0]) - if last_model_num is None or model_num > last_model_num: - last_model_num = model_num - last_model = file_name - - # if there is not checkpoint found above - # find the checkpoint with the latest - # modification date. 
- key_file_names = [fn for fn in file_names if key in fn] - if last_model is None and len(key_file_names) > 0: - last_model = max(key_file_names, key=os.path.getctime) - last_model_num = torch.load(last_model)["step"] - - if last_model is not None: - last_models[key] = last_model - last_model_nums[key] = last_model_num - - # check what models were found - if not last_models: - raise ValueError(f"No models found in continue path {path}!") - if "checkpoint" not in last_models: # no checkpoint just best model - last_models["checkpoint"] = last_models["best_model"] - elif "best_model" not in last_models: # no best model - # this shouldn't happen, but let's handle it just in case - last_models["best_model"] = None - # finally check if last best model is more recent than checkpoint - elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: - last_models["checkpoint"] = last_models["best_model"] - - return last_models["checkpoint"], last_models["best_model"] - - -def process_args(args): - """Process parsed comand line arguments. - - Args: - args (argparse.Namespace or dict like): Parsed input arguments. - - Returns: - c (TTS.utils.io.AttrDict): Config paramaters. - out_path (str): Path to save models and logging. - audio_path (str): Path to save generated test audios. - c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does - logging to the console. - tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does - the TensorBoard loggind. 
- """ - if isinstance(args, tuple): - args, coqpit_overrides = args - if args.continue_path: - # continue a previous training from its output folder - experiment_path = args.continue_path - args.config_path = os.path.join(args.continue_path, "config.json") - args.restore_path, best_model = get_last_checkpoint(args.continue_path) - if not args.best_path: - args.best_path = best_model - # setup output paths and read configs - config = load_config(args.config_path) - # override values from command-line args - config.parse_known_args(coqpit_overrides, relaxed_parser=True) - if config.mixed_precision: - print(" > Mixed precision mode is ON") - experiment_path = args.continue_path - if not experiment_path: - experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug) - audio_path = os.path.join(experiment_path, "test_audios") - # setup rank 0 process in distributed training - tb_logger = None - if args.rank == 0: - os.makedirs(audio_path, exist_ok=True) - new_fields = {} - if args.restore_path: - new_fields["restore_path"] = args.restore_path - new_fields["github_branch"] = get_git_branch() - # if model characters are not set in the config file - # save the default set to the config file for future - # compatibility. - if config.has("characters_config"): - used_characters = parse_symbols() - new_fields["characters"] = used_characters - copy_model_files(config, experiment_path, new_fields) - os.chmod(audio_path, 0o775) - os.chmod(experiment_path, 0o775) - tb_logger = TensorboardLogger(experiment_path, model_name=config.model) - # write model desc to tensorboard - tb_logger.tb_add_text("model-config", f"
{config.to_json()}
", 0) - c_logger = ConsoleLogger() - return config, experiment_path, audio_path, c_logger, tb_logger - - -def init_training(argv): - """Initialization of a training run.""" - parser = init_arguments(argv) - args = parser.parse_known_args() - config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args) - return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger diff --git a/TTS/utils/callbacks.py b/TTS/utils/callbacks.py new file mode 100644 index 00000000..18b6c34c --- /dev/null +++ b/TTS/utils/callbacks.py @@ -0,0 +1,75 @@ +class TrainerCallback: + def __init__(self, trainer): + super().__init__() + self.trainer = trainer + + def on_init_start(self) -> None: + if hasattr(self.trainer.model, "on_init_start"): + self.trainer.model.on_init_start(self.trainer) + + if hasattr(self.trainer.criterion, "on_init_start"): + self.trainer.criterion.on_init_start(self.trainer) + + if hasattr(self.trainer.optimizer, "on_init_start"): + self.trainer.optimizer.on_init_start(self.trainer) + + def on_init_end(self) -> None: + if hasattr(self.trainer.model, "on_init_end"): + self.trainer.model.on_init_end(self.trainer) + + if hasattr(self.trainer.criterion, "on_init_end"): + self.trainer.criterion.on_init_end(self.trainer) + + if hasattr(self.trainer.optimizer, "on_init_end"): + self.trainer.optimizer.on_init_end(self.trainer) + + def on_epoch_start(self) -> None: + if hasattr(self.trainer.model, "on_epoch_start"): + self.trainer.model.on_epoch_start(self.trainer) + + if hasattr(self.trainer.criterion, "on_epoch_start"): + self.trainer.criterion.on_epoch_start(self.trainer) + + if hasattr(self.trainer.optimizer, "on_epoch_start"): + self.trainer.optimizer.on_epoch_start(self.trainer) + + def on_epoch_end(self) -> None: + if hasattr(self.trainer.model, "on_epoch_end"): + self.trainer.model.on_epoch_end(self.trainer) + + if hasattr(self.trainer.criterion, "on_epoch_end"): + self.trainer.criterion.on_epoch_end(self.trainer) + + if hasattr(self.trainer.optimizer, 
"on_epoch_end"): + self.trainer.optimizer.on_epoch_end(self.trainer) + + def on_train_step_start(self) -> None: + if hasattr(self.trainer.model, "on_train_step_start"): + self.trainer.model.on_train_step_start(self.trainer) + + if hasattr(self.trainer.criterion, "on_train_step_start"): + self.trainer.criterion.on_train_step_start(self.trainer) + + if hasattr(self.trainer.optimizer, "on_train_step_start"): + self.trainer.optimizer.on_train_step_start(self.trainer) + + def on_train_step_end(self) -> None: + + if hasattr(self.trainer.model, "on_train_step_end"): + self.trainer.model.on_train_step_end(self.trainer) + + if hasattr(self.trainer.criterion, "on_train_step_end"): + self.trainer.criterion.on_train_step_end(self.trainer) + + if hasattr(self.trainer.optimizer, "on_train_step_end"): + self.trainer.optimizer.on_train_step_end(self.trainer) + + def on_keyboard_interrupt(self) -> None: + if hasattr(self.trainer.model, "on_keyboard_interrupt"): + self.trainer.model.on_keyboard_interrupt(self.trainer) + + if hasattr(self.trainer.criterion, "on_keyboard_interrupt"): + self.trainer.criterion.on_keyboard_interrupt(self.trainer) + + if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"): + self.trainer.optimizer.on_keyboard_interrupt(self.trainer) diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py index 7a1078e8..1c6b0e1c 100644 --- a/TTS/utils/distribute.py +++ b/TTS/utils/distribute.py @@ -1,53 +1,8 @@ # edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py -import math - import torch import torch.distributed as dist from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.autograd import Variable -from torch.utils.data.sampler import Sampler - - -class DistributedSampler(Sampler): - """ - Non shuffling Distributed Sampler - """ - - def __init__(self, dataset, num_replicas=None, rank=None): - super().__init__(dataset) - if num_replicas is None: - if not dist.is_available(): - raise 
RuntimeError("Requires distributed package to be available") - num_replicas = dist.get_world_size() - if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - rank = dist.get_rank() - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas - - def __iter__(self): - indices = torch.arange(len(self.dataset)).tolist() - - # add extra samples to make it evenly divisible - indices += indices[: (self.total_size - len(indices))] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.rank : self.total_size : self.num_replicas] - assert len(indices) == self.num_samples - - return iter(indices) - - def __len__(self): - return self.num_samples - - def set_epoch(self, epoch): - self.epoch = epoch def reduce_tensor(tensor, num_gpus): diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py new file mode 100644 index 00000000..02e68905 --- /dev/null +++ b/TTS/utils/trainer_utils.py @@ -0,0 +1,65 @@ +import importlib +from typing import Dict + +import torch + +from TTS.utils.training import NoamLR + + +def is_apex_available(): + return importlib.util.find_spec("apex") is not None + + +def setup_torch_training_env(cudnn_enable, cudnn_benchmark): + torch.backends.cudnn.enabled = cudnn_enable + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.manual_seed(54321) + use_cuda = torch.cuda.is_available() + num_gpus = torch.cuda.device_count() + print(" > Using CUDA: ", use_cuda) + print(" > Number of GPUs: ", num_gpus) + return use_cuda, num_gpus + + +def get_scheduler( + lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer +) -> torch.optim.lr_scheduler._LRScheduler: # pylint: disable=protected-access + """Find, initialize and return a scheduler. 
+ + Args: + lr_scheduler (str): Scheduler name. + lr_scheduler_params (Dict): Scheduler parameters. + optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler. + + Returns: + torch.optim.lr_scheduler._LRScheduler: Functional scheduler. + """ + if lr_scheduler is None: + return None + if lr_scheduler.lower() == "noamlr": + scheduler = NoamLR + else: + scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler) + return scheduler(optimizer, **lr_scheduler_params) + + +def get_optimizer( + optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module +) -> torch.optim.Optimizer: + """Find, initialize and return a optimizer. + + Args: + optimizer_name (str): Optimizer name. + optimizer_params (dict): Optimizer parameters. + lr (float): Initial learning rate. + model (torch.nn.Module): Model to pass to the optimizer. + + Returns: + torch.optim.Optimizer: Functional optimizer. + """ + if optimizer_name.lower() == "radam": + module = importlib.import_module("TTS.utils.radam") + optimizer = getattr(module, "RAdam") + else: + optimizer = getattr(torch.optim, optimizer_name) + return optimizer(model.parameters(), lr=lr, **optimizer_params) diff --git a/TTS/utils/training.py b/TTS/utils/training.py index 37b32637..aa5651c5 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -2,17 +2,6 @@ import numpy as np import torch -def setup_torch_training_env(cudnn_enable, cudnn_benchmark): - torch.backends.cudnn.enabled = cudnn_enable - torch.backends.cudnn.benchmark = cudnn_benchmark - torch.manual_seed(54321) - use_cuda = torch.cuda.is_available() - num_gpus = torch.cuda.device_count() - print(" > Using CUDA: ", use_cuda) - print(" > Number of GPUs: ", num_gpus) - return use_cuda, num_gpus - - def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): r"""Check model gradient against unexpected jumps and failures""" skip_flag = False @@ -41,46 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, 
def lr_decay(init_lr, global_step, warmup_steps):
    r"""Noam-style learning-rate schedule: linear warmup, then inverse-sqrt decay.

    From https://github.com/r9y9/tacotron_pytorch/blob/master/train.py
    It is only being used by the Speaker Encoder trainer.

    Args:
        init_lr (float): Base learning rate.
        global_step (int): Current (0-based) training step.
        warmup_steps (int): Number of warmup steps before the decay dominates.

    Returns:
        float: Learning rate for this step.
    """
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    # During warmup the first argument of the min() is smaller (linear ramp);
    # afterwards the step**-0.5 decay term takes over.
    return init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)


# Default parameter-name fragments excluded from weight decay. A module-level
# frozenset replaces the former mutable-set default argument (a shared mutable
# default is a classic Python pitfall that was silenced with a pylint disable).
_NO_DECAY_NAME_FRAGMENTS = frozenset({"decoder.attention.v", "rnn", "lstm", "gru", "embedding"})


def set_weight_decay(model, weight_decay, skip_list=_NO_DECAY_NAME_FRAGMENTS):
    """Build optimizer parameter groups, skipping weight decay where it hurts.

    Biases and other 1-D parameters (e.g. BatchNorm scales), RNN weights and
    the attention projection layer ``v`` are placed in a ``weight_decay=0.0``
    group; all remaining trainable parameters use ``weight_decay``.

    Args:
        model (torch.nn.Module): Model providing ``named_parameters()``.
        weight_decay (float): Decay factor for the decayed group.
        skip_list (Collection[str]): Name fragments whose parameters are
            excluded from weight decay.

    Returns:
        List[Dict]: Two parameter groups consumable by a torch optimizer.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            # Frozen parameters are not optimized at all.
            continue

        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]