import argparse
import math
import os
import pickle
import shutil
import sys
import traceback
import time
import glob
import random

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.training import setup_torch_training_env
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.generic_utils import (
    KeepAverage,
    count_parameters,
    create_experiment_folder,
    get_git_branch,
    remove_experiment_folder,
    set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
    load_wav_data,
    find_feat_files,
    load_wav_feat_data,
    preprocess_wav_files,
)
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_wavernn
from TTS.vocoder.utils.io import save_best_model, save_checkpoint


use_cuda, num_gpus = setup_torch_training_env(True, True)


def setup_loader(ap, is_val=False, verbose=False):
    if is_val and not CONFIG.run_eval:
        loader = None
    else:
        dataset = WaveRNNDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=CONFIG.seq_len,
            hop_len=ap.hop_length,
            pad=CONFIG.padding,
            mode=CONFIG.mode,
            is_training=not is_val,
            verbose=verbose,
        )
        # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            shuffle=True,
            collate_fn=dataset.collate,
            batch_size=CONFIG.batch_size,
            num_workers=CONFIG.num_val_loader_workers
            if is_val
            else CONFIG.num_loader_workers,
            pin_memory=True,
        )
    return loader


def format_data(data):
    # setup input data
    x = data[0]
    m = data[1]
    y = data[2]

    # dispatch data to GPU
    if use_cuda:
        x = x.cuda(non_blocking=True)
        m = m.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)

    return x, m, y


def train(model, optimizer, criterion, scheduler, ap, global_step, epoch):
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (CONFIG.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / CONFIG.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # train loop
    print(" > Training", flush=True)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x, m, y = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        ##################
        # MODEL TRAINING #
        ##################
        y_hat = model(x, m)

        if isinstance(model.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y = y.float()
        y = y.unsqueeze(-1)
        # m_scaled, _ = model.upsample(m)

        # compute losses
        loss = criterion(y_hat, y)
        # stop training if the loss diverges
        if torch.isnan(loss):
            raise RuntimeError(" [!] NaN loss. Exiting ...")
Exiting ...") optimizer.zero_grad() loss.backward() if CONFIG.grad_clip > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG.grad_clip) optimizer.step() if scheduler is not None: scheduler.step() # get the current learning rate cur_lr = list(optimizer.param_groups)[0]["lr"] step_time = time.time() - start_time epoch_time += step_time update_train_values = dict() loss_dict = dict() loss_dict["model_loss"] = loss.item() for key, value in loss_dict.items(): update_train_values["avg_" + key] = value update_train_values["avg_loader_time"] = loader_time update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training stats if global_step % CONFIG.print_step == 0: log_dict = { "step_time": [step_time, 2], "loader_time": [loader_time, 4], "current_lr": cur_lr, } c_logger.print_train_step( batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values, ) # plot step stats if global_step % 10 == 0: iter_stats = {"lr": cur_lr, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) # save checkpoint if global_step % CONFIG.save_step == 0: if CONFIG.checkpoint: # save model save_checkpoint( model, optimizer, scheduler, None, None, None, global_step, epoch, OUT_PATH, model_losses=loss_dict, ) # synthesize a full voice wav_path = train_data[random.randrange(0, len(train_data))][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) sample_wav = model.generate( ground_mel, CONFIG.batched, CONFIG.target_samples, CONFIG.overlap_samples, ) predict_mel = ap.melspectrogram(sample_wav) # compute spectrograms figures = { "train/ground_truth": plot_spectrogram(ground_mel.T), "train/prediction": plot_spectrogram(predict_mel.T), } # Sample audio tb_logger.tb_train_audios( global_step, {"train/audio": sample_wav}, CONFIG.audio["sample_rate"] ) tb_logger.tb_train_figures(global_step, figures) end_time = time.time() # print epoch stats c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) # Plot Training Epoch Stats epoch_stats = {"epoch_time": epoch_time} epoch_stats.update(keep_avg.avg_values) tb_logger.tb_train_epoch_stats(global_step, epoch_stats) # TODO: plot model stats # if c.tb_model_param_stats: # tb_logger.tb_model_weights(model, global_step) return keep_avg.avg_values, global_step @torch.no_grad() def evaluate(model, criterion, ap, global_step, epoch): # create train loader data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0)) model.eval() epoch_time = 0 keep_avg = KeepAverage() end_time = time.time() c_logger.print_eval_start() with torch.no_grad(): for num_iter, data in enumerate(data_loader): start_time = time.time() # format data x, m, y = format_data(data) loader_time = time.time() - end_time global_step += 1 y_hat = model(x, m) if isinstance(model.mode, int): y_hat = y_hat.transpose(1, 2).unsqueeze(-1) else: y = y.float() y = y.unsqueeze(-1) loss = criterion(y_hat, y) # Compute avg loss # if num_gpus > 1: # loss = reduce_tensor(loss.data, num_gpus) loss_dict = dict() loss_dict["model_loss"] = loss.item() step_time = time.time() - start_time epoch_time += step_time # update avg stats update_eval_values = dict() for key, value in loss_dict.items(): update_eval_values["avg_" + key] = value update_eval_values["avg_loader_time"] = loader_time update_eval_values["avg_step_time"] = step_time keep_avg.update_values(update_eval_values) # print eval stats if CONFIG.print_eval: c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if 
    if epoch % CONFIG.test_every_epochs == 0:
        # synthesize a part of data
        wav_path = eval_data[random.randrange(0, len(eval_data))][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav[:22000])
        sample_wav = model.generate(
            ground_mel,
            CONFIG.batched,
            CONFIG.target_samples,
            CONFIG.overlap_samples,
        )
        predict_mel = ap.melspectrogram(sample_wav)

        # compute spectrograms
        figures = {
            "eval/ground_truth": plot_spectrogram(ground_mel.T),
            "eval/prediction": plot_spectrogram(predict_mel.T),
        }

        # Sample audio
        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": sample_wav}, CONFIG.audio["sample_rate"]
        )
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**CONFIG.audio)

    print(f" > Loading wavs from: {CONFIG.data_path}")
    if CONFIG.feature_path is not None:
        print(f" > Loading features from: {CONFIG.feature_path}")
        eval_data, train_data = load_wav_feat_data(
            CONFIG.data_path, CONFIG.feature_path, CONFIG.eval_split_size
        )
    else:
        mel_feat_path = os.path.join(OUT_PATH, "mel")
        feat_data = find_feat_files(mel_feat_path)
        if feat_data:
            print(f" > Loading features from: {mel_feat_path}")
            eval_data, train_data = load_wav_feat_data(
                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
            )
        else:
            print(" > No feature data found. Preprocessing...")
            # preprocess feature data from the given wav files
            preprocess_wav_files(OUT_PATH, CONFIG, ap)
            eval_data, train_data = load_wav_feat_data(
                CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
            )

    # setup model
    model_wavernn = setup_wavernn(CONFIG)

    # setup loss criterion
    if CONFIG.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif CONFIG.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(CONFIG.mode, int):
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError(f" [!] Unknown model mode: {CONFIG.mode}")

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(CONFIG.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=CONFIG.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in CONFIG:
        scheduler = getattr(torch.optim.lr_scheduler, CONFIG.lr_scheduler)
        scheduler = scheduler(optimizer, **CONFIG.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / CONFIG.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint and scheduler is not None:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            # TODO: fix resetting restored optimizer lr
            # optimizer.load_state_dict(checkpoint["optimizer"])
        except RuntimeError:
            # restore only matching layers.
print(" > Partial model initialization...") model_dict = model_wavernn.state_dict() model_dict = set_init_dict(model_dict, checkpoint["model"], CONFIG) model_wavernn.load_state_dict(model_dict) print(" > Model restored from step %d" % checkpoint["step"], flush=True) args.restore_step = checkpoint["step"] else: args.restore_step = 0 # DISTRIBUTED # if num_gpus > 1: # model = apply_gradient_allreduce(model) num_parameters = count_parameters(model_wavernn) print(" > Model has {} parameters".format(num_parameters), flush=True) if "best_loss" not in locals(): best_loss = float("inf") global_step = args.restore_step for epoch in range(0, CONFIG.epochs): c_logger.print_epoch_start(epoch, CONFIG.epochs) _, global_step = train( model_wavernn, optimizer, criterion, scheduler, ap, global_step, epoch ) eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict["avg_model_loss"] best_loss = save_best_model( target_loss, best_loss, model_wavernn, optimizer, scheduler, None, None, None, global_step, epoch, OUT_PATH, model_losses=eval_avg_loss_dict, ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--continue_path", type=str, help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', default="", required="--config_path" not in sys.argv, ) parser.add_argument( "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="", ) parser.add_argument( "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv, ) parser.add_argument( "--debug", type=bool, default=False, help="Do not verify commit integrity to run training.", ) # DISTRUBUTED parser.add_argument( "--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.", ) parser.add_argument( "--group_id", type=str, default="", help="DISTRIBUTED: process group id." 
    args = parser.parse_args()

    if args.continue_path != "":
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        # pick the most recent checkpoint in the output folder
        list_of_files = glob.glob(os.path.join(args.continue_path, "*.pth.tar"))
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues from {args.restore_path}")

    # setup output paths and read configs
    CONFIG = load_config(args.config_path)
    # check_config(c)

    OUT_PATH = args.continue_path
    if args.continue_path == "":
        OUT_PATH = create_experiment_folder(
            CONFIG.output_path, CONFIG.run_name, args.debug
        )

    AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(
            args.config_path, os.path.join(OUT_PATH, "config.json"), new_fields
        )
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")

        # write model desc to tensorboard
        tb_logger.tb_add_text("model-description", CONFIG["run_description"], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
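
# Example invocations (a sketch; the script filename below is a placeholder,
# use the actual name of this file in your checkout):
#
#   # start a fresh run from a config file
#   python train_vocoder_wavernn.py --config_path path/to/config.json
#
#   # resume the latest *.pth.tar checkpoint of an existing run
#   python train_vocoder_wavernn.py --continue_path path/to/output/run/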