#!/usr/bin/env python3
# TODO: mixed precision training
"""Trains GAN based vocoder model."""

import itertools
import os
import sys
import time
import traceback
from inspect import signature

import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    """Build the ``DataLoader`` for the training or the evaluation split.

    Reads the module-level config ``c`` and the ``train_data`` / ``eval_data``
    globals populated by ``main``.

    Args:
        ap: audio processor; only ``ap.hop_length`` is read here, the object
            itself is forwarded to ``GANDataset``.
        is_val: build the evaluation loader (batch size 1, no segment
            cropping) instead of the training loader.
        verbose: forwarded to ``GANDataset``.

    Returns:
        A ``DataLoader``, or ``None`` when ``is_val`` is true and
        ``c.run_eval`` is falsy.
    """
    loader = None
    if not is_val or c.run_eval:
        dataset = GANDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=c.seq_len,
            hop_len=ap.hop_length,
            pad_short=c.pad_short,
            conv_pad=c.conv_pad,
            return_pairs=c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False,
            is_training=not is_val,
            return_segments=not is_val,
            use_noise_augment=c.use_noise_augment,
            use_cache=c.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_val else c.batch_size,
            # NOTE(review): loader-level shuffling is only enabled when no GPU is
            # used; presumably `dataset.shuffle_mapping()` covers the single-GPU
            # case -- confirm against GANDataset.
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    """Unpack a batch and move its tensors to the GPU when ``use_cuda`` is set.

    Args:
        data: either a pair of ``[x, y]`` lists (separate samples for the
            generator and the discriminator) or a single ``(x, y)`` pair.

    Returns:
        ``(x_G, y_G, x_D, y_D)`` when ``data[0]`` is a list, otherwise
        ``(x, y, None, None)``.
    """
    if isinstance(data[0], list):
        x_G, y_G = data[0]
        x_D, y_D = data[1]
        if use_cuda:
            x_G = x_G.cuda(non_blocking=True)
            y_G = y_G.cuda(non_blocking=True)
            x_D = x_D.cuda(non_blocking=True)
            y_D = y_D.cuda(non_blocking=True)
        return x_G, y_G, x_D, y_D
    x, y = data
    if use_cuda:
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)
    return x, y, None, None


def _discriminator_takes_cond(model_D):
    """Return True when ``model_D.forward`` exposes exactly two parameters.

    Two parameters are treated as "signal + conditioning features".
    """
    # NOTE(review): `main` may wrap the model in DistributedDataParallel before
    # training; confirm the wrapper's `forward` signature still reflects the
    # underlying discriminator in that case.
    return len(signature(model_D.forward).parameters) == 2


def _collect_scalar_losses(loss_values, loss_dict):
    """Copy ``loss_values`` into ``loss_dict`` as plain Python numbers.

    Plain ``int`` / ``float`` values are stored unchanged; anything else is
    converted with ``.item()``.
    """
    for key, value in loss_values.items():
        if isinstance(value, (int, float)):
            loss_dict[key] = value
        else:
            loss_dict[key] = value.item()


def train(
    model_G,
    criterion_G,
    optimizer_G,
    model_D,
    criterion_D,
    optimizer_D,
    scheduler_G,
    scheduler_D,
    ap,
    global_step,
    epoch,
):
    """Run one training epoch over the generator and the discriminator.

    Args:
        model_G, model_D: generator and discriminator models.
        criterion_G, criterion_D: loss modules; they return dicts holding
            ``"G_loss"`` and ``"D_loss"`` respectively.
        optimizer_G, optimizer_D: optimizers of the two models.
        scheduler_G, scheduler_D: LR schedulers or ``None``; stepped once at
            the end of the epoch.
        ap: audio processor used for data loading and plotting.
        global_step: step counter at the start of the epoch.
        epoch: current epoch index.

    Returns:
        ``(keep_avg.avg_values, global_step)`` with the updated step counter.
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)

    # Loop invariants: whether D takes conditioning features and whether G and
    # D train on different samples. The config key is guarded the same way as
    # in `setup_loader` so a config without it does not fail here.
    d_takes_cond = _discriminator_takes_cond(model_D)
    use_diff_samples = c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False

    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    # the real pass must receive the same conditioning as the
                    # fake pass when D expects cond. features
                    if d_takes_cond:
                        D_out_real = model_D(y_G, c_G)
                    else:
                        D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    # we don't need scores for real samples for training G since they are always 1
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(
            y_hat=y_hat,
            y=y_G,
            scores_fake=scores_fake,
            feats_fake=feats_fake,
            feats_real=feats_real,
            y_hat_sub=y_hat_sub,
            y_sub=y_G_sub,
        )
        loss_G = loss_G_dict["G_loss"]

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
        optimizer_G.step()

        loss_dict = {}
        _collect_scalar_losses(loss_G_dict, loss_dict)

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            if use_diff_samples:
                # use a different sample than generator
                with torch.no_grad():
                    y_hat = model_G(c_D)

                # PQMF formatting
                if y_hat.shape[1] > 1:
                    y_hat = model_G.pqmf_synthesis(y_hat)
            else:
                # use the same samples as generator
                c_D = c_G.clone()
                y_D = y_G.clone()

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat.detach().clone(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                # model_D returns scores and features
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                # model D returns only scores
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict["D_loss"]

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
            optimizer_D.step()

            _collect_scalar_losses(loss_D_dict, loss_dict)

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = optimizer_G.param_groups[0]["lr"]
        current_lr_D = optimizer_D.param_groups[0]["lr"]

        # update avg stats
        update_train_values = {"avg_" + key: value for key, value in loss_dict.items()}
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr_G": current_lr_G,
                "current_lr_D": current_lr_D,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_step_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model_G,
                        optimizer_G,
                        scheduler_G,
                        model_D,
                        optimizer_D,
                        scheduler_D,
                        global_step,
                        epoch,
                        OUT_PATH,
                        model_losses=loss_dict,
                    )

                # compute spectrograms
                figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"])
        end_time = time.time()

    if scheduler_G is not None:
        scheduler_G.step()

    if scheduler_D is not None:
        scheduler_D.step()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    torch.cuda.empty_cache()
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
    """Run one evaluation pass and log figures/audio of the last batch.

    Args:
        model_G, model_D: generator and discriminator models.
        criterion_G, criterion_D: loss modules of the two models.
        ap: audio processor used for data loading and plotting.
        global_step: current step counter; incremented locally per batch and
            not returned to the caller.
        epoch: current epoch index.

    Returns:
        ``keep_avg.avg_values`` with the averaged evaluation statistics.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    d_takes_cond = _discriminator_takes_cond(model_D)
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass; crop the output to the target length
        y_hat = model_G(c_G)[:, :, : y_G.size(2)]
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            if d_takes_cond:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    # the real pass must receive the same conditioning as the
                    # fake pass when D expects cond. features
                    if d_takes_cond:
                        D_out_real = model_D(y_G, c_G)
                    else:
                        D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                feats_fake, feats_real = None, None

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)

        loss_dict = {}
        _collect_scalar_losses(loss_G_dict, loss_dict)

        ##############################
        # DISCRIMINATOR
        ##############################

        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            with torch.no_grad():
                y_hat = model_G(c_G)[:, :, : y_G.size(2)]

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if d_takes_cond:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            _collect_scalar_losses(loss_D_dict, loss_dict)

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = {"avg_" + key: value for key, value in loss_dict.items()}
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if args.rank == 0:
        # compute spectrograms (uses the tensors of the last evaluated batch)
        figures = plot_results(y_hat, y_G, ap, global_step, "eval")
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
        real_waveform = y_G[0].squeeze(0).cpu().numpy()
        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": predict_waveform, "eval/real_waveformo": real_waveform}, c.audio["sample_rate"]
        )

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    # synthesize a full voice
    # NOTE(review): this sets an attribute on the loader (not on its dataset)
    # right before the loader goes out of scope; it appears to have no effect.
    data_loader.return_segments = False
    torch.cuda.empty_cache()
    return keep_avg.avg_values


def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, models, optimizers and schedulers, then train and evaluate.

    Relies on the module-level ``c``, ``OUT_PATH``, ``c_logger`` and
    ``tb_logger`` created in the ``__main__`` guard. Optionally restores a
    checkpoint from ``args.restore_path`` and stores the restored step in
    ``args.restore_step``.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # setup optimizers
    # TODO: allow loading custom optimizers
    optimizer_gen = None
    optimizer_disc = None
    optimizer_gen = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
    optimizer_disc = getattr(torch.optim, c.optimizer)

    if c.discriminator_model == "hifigan_discriminator":
        optimizer_disc = optimizer_disc(
            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
            lr=c.lr_disc,
            **c.optimizer_params,
        )
    else:
        optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if "lr_scheduler_gen" in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    if "lr_scheduler_disc" in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        # NOTE: torch.load unpickles the file; only restore checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint["model"])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint["optimizer"])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint["model_disc"])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
            # restore schedulers if it is a continuing training.
            if args.continue_path != "":
                if "scheduler" in checkpoint and scheduler_gen is not None:
                    print(" > Restoring Generator LR Scheduler...")
                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
                    # NOTE: Not sure if necessary
                    scheduler_gen.optimizer = optimizer_gen
                if "scheduler_disc" in checkpoint and scheduler_disc is not None:
                    print(" > Restoring Discriminator LR Scheduler...")
                    scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
                    scheduler_disc.optimizer = optimizer_disc
                    if c.lr_scheduler_disc == "ExponentialLR":
                        scheduler_disc.last_epoch = checkpoint["epoch"]
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        if args.continue_path == "":
            for group in optimizer_gen.param_groups:
                group["lr"] = c.lr_gen

            for group in optimizer_disc.param_groups:
                group["lr"] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(f" > Generator has {num_params} parameters", flush=True)
    num_params = count_parameters(model_disc)
    print(f" > Discriminator has {num_params} parameters", flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(
            model_gen,
            criterion_gen,
            optimizer_gen,
            model_disc,
            criterion_disc,
            optimizer_disc,
            scheduler_gen,
            scheduler_disc,
            ap,
            global_step,
            epoch,
        )
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )


if __name__ == "__main__":
    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)

    try:
        main(args)
    except KeyboardInterrupt:
        # user aborted: clean up the experiment folder and exit with status 0
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        # top-level boundary: clean up, report the failure and exit non-zero
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)