From 8e52a6923035c729244532d506b557465f2a5927 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Fri, 28 May 2021 14:03:02 +0200
Subject: [PATCH] delete separate tts training scripts and update pre-commit
 configuration

---
 .pre-commit-config.yaml        |  17 +-
 TTS/bin/train_align_tts.py     | 572 ------------------------
 TTS/bin/train_glow_tts.py      | 598 --------------------------
 TTS/bin/train_speedy_speech.py | 578 -------------------------
 TTS/bin/train_tacotron.py      | 749 ---------------------------------
 5 files changed, 16 insertions(+), 2498 deletions(-)
 delete mode 100644 TTS/bin/train_align_tts.py
 delete mode 100644 TTS/bin/train_glow_tts.py
 delete mode 100644 TTS/bin/train_speedy_speech.py
 delete mode 100755 TTS/bin/train_tacotron.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 1ae28644..a70572dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,4 +9,19 @@ repos:
       rev: 20.8b1
       hooks:
         - id: black
-          language_version: python3
\ No newline at end of file
+          language_version: python3
+    - repo: https://github.com/pycqa/isort
+      rev: 5.8.0
+      hooks:
+        - id: isort
+          name: isort (python)
+        - id: isort
+          name: isort (cython)
+          types: [cython]
+        - id: isort
+          name: isort (pyi)
+          types: [pyi]
+    - repo: https://github.com/pycqa/pylint
+      rev: v2.8.2
+      hooks:
+        - id: pylint
diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py
deleted file mode 100644
index 34eba7a8..00000000
--- a/TTS/bin/train_align_tts.py
+++ /dev/null
@@ -1,572 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import os
-import sys
-import time
-import traceback
-from random import randrange
-
-import numpy as np
-import torch
-from torch.nn.parallel import DistributedDataParallel as DDP_th
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from TTS.tts.datasets import load_meta_data
-from TTS.tts.datasets.TTSDataset import TTSDataset
-from TTS.tts.layers.losses import AlignTTSLoss
-from TTS.tts.models import setup_model
-from TTS.tts.utils.io import save_best_model, save_checkpoint
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.speakers import parse_speakers
-from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.arguments import init_training
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.distribute import init_distributed, reduce_tensor
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
-from TTS.utils.radam import RAdam
-from TTS.utils.training import NoamLR, setup_torch_training_env
-
-use_cuda, num_gpus = setup_torch_training_env(True, False)
-# torch.autograd.set_detect_anomaly(True)
-
-
-def setup_loader(ap, r, is_val=False, verbose=False):
-    if is_val and not config.run_eval:
-        loader = None
-    else:
-        dataset = TTSDataset(
-            r,
-            config.text_cleaner,
-            compute_linear_spec=False,
-            meta_data=meta_data_eval if is_val else meta_data_train,
-            ap=ap,
-            tp=config.characters,
-            add_blank=config["add_blank"],
-            batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
-            min_seq_len=config.min_seq_len,
-            max_seq_len=config.max_seq_len,
-            phoneme_cache_path=config.phoneme_cache_path,
-            use_phonemes=config.use_phonemes,
-            phoneme_language=config.phoneme_language,
-            enable_eos_bos=config.enable_eos_bos_chars,
-            use_noise_augment=not is_val,
-
verbose=verbose, - speaker_mapping=speaker_mapping - if config.use_speaker_embedding and config.use_external_speaker_embedding_file - else None, - ) - - if config.use_phonemes and config.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. - dataset.compute_input_seq(config.num_loader_workers) - dataset.sort_items() - - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=config.eval_batch_size if is_val else config.batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - sampler=sampler, - num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers, - pin_memory=False, - ) - return loader - - -def format_data(data): - # setup input data - text_input = data[0] - text_lengths = data[1] - speaker_names = data[2] - mel_input = data[4].permute(0, 2, 1) # B x D x T - mel_lengths = data[5] - item_idx = data[7] - avg_text_length = torch.mean(text_lengths.float()) - avg_spec_length = torch.mean(mel_lengths.float()) - - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - # return precomputed embedding vector - speaker_c = data[8] - else: - # return speaker_id to be used by an embedding layer - speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] - speaker_c = torch.LongTensor(speaker_c) - else: - speaker_c = None - # dispatch data to GPU - if use_cuda: - text_input = text_input.cuda(non_blocking=True) - text_lengths = text_lengths.cuda(non_blocking=True) - mel_input = mel_input.cuda(non_blocking=True) - mel_lengths = mel_lengths.cuda(non_blocking=True) - if speaker_c is not None: - speaker_c = speaker_c.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx - - -def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase): - - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / config.batch_size) - end_time = time.time() - c_logger.print_train_start() - scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - ( - text_input, - text_lengths, - mel_targets, - mel_lengths, - speaker_c, - avg_text_length, - avg_spec_length, - _, - ) = format_data(data) - - loader_time = time.time() - end_time - - global_step += 1 - optimizer.zero_grad() - - # forward pass model - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase - ) - - # compute loss - loss_dict = criterion( - logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase, - ) - - # backward pass with loss scaling - if config.mixed_precision: - scaler.scale(loss_dict["loss"]).backward() - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - loss_dict["loss"].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - optimizer.step() - - # setup lr - if 
config.noam_schedule: - scheduler.step() - - # current_lr - current_lr = optimizer.param_groups[0]["lr"] - - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict["align_error"] = align_error - - step_time = time.time() - start_time - epoch_time += step_time - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) - loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # print training progress - if global_step % config.print_step == 0: - log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision - "avg_text_length": [avg_text_length, 1], - "step_time": [step_time, 4], - "loader_time": [loader_time, 2], - "current_lr": current_lr, - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Plot Training Iter Stats - # reduce TB load - if global_step % config.tb_plot_step == 0: - iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - if global_step % config.save_step == 0: - if config.checkpoint: - # save model - save_checkpoint( - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - model_loss=loss_dict["loss"], - ) - - # wait all kernels to be completed - torch.cuda.synchronize() - - # Diagnostic visualizations - if decoder_output is not None: - idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu().numpy().T - gt_spec = mel_targets[idx].data.cpu().numpy().T - align_img = alignments[idx].data.cpu() - - figures = { - "prediction": plot_spectrogram(pred_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img), - } - - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - train_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"]) - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Epoch Stats - if args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - if config.tb_model_param_stats: - tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase): - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - c_logger.print_eval_start() - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - 
start_time = time.time() - - # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data) - - # forward pass model - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase - ) - - # compute loss - loss_dict = criterion( - logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase, - ) - - # step time - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict["align_error"] = align_error - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) - loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - keep_avg.update_values(update_train_values) - - if config.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Diagnostic visualizations - idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu().numpy().T - gt_spec = mel_targets[idx].data.cpu().numpy().T - align_img = alignments[idx].data.cpu() - - eval_figures = { - "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False), - } - - # Sample audio - eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"]) - - # Plot Validation Stats - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - tb_logger.tb_eval_figures(global_step, eval_figures) - - if args.rank == 0 and epoch >= config.test_delay_epochs: - if config.test_sentences_file: - with open(config.test_sentences_file, "r") as f: - test_sentences = [s.strip() for s in f.readlines()] - else: - test_sentences = [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. 
It's so delicious and moist.", - "Prior to November 22, 1963.", - ] - - # test sentences - test_audios = {} - test_figures = {} - print(" | > Synthesizing test sentences") - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ - "embedding" - ] - speaker_id = None - else: - speaker_id = 0 - speaker_embedding = None - else: - speaker_id = None - speaker_embedding = None - - for idx, test_sentence in enumerate(test_sentences): - try: - wav, alignment, _, postnet_output, _, _ = synthesis( - model, - test_sentence, - config, - use_cuda, - ap, - speaker_id=speaker_id, - speaker_embedding=speaker_embedding, - style_wav=None, - truncated=False, - enable_eos_bos_chars=config.enable_eos_bos_chars, # pylint: disable=unused-argument - use_griffin_lim=True, - do_trim_silence=False, - ) - - file_path = os.path.join(AUDIO_PATH, str(global_step)) - os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) - ap.save_wav(wav, file_path) - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) - except: # pylint: disable=bare-except - print(" !! Error creating Test Sentence -", idx) - traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"]) - tb_logger.tb_test_figures(global_step, test_figures) - return keep_avg.avg_values - - -def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping - # Audio processor - ap = AudioProcessor(**config.audio.to_dict()) - if config.has("characters") and config.characters: - symbols, phonemes = make_symbols(**config.characters.to_dict()) - - # DISTRUBUTED - if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"]) - - # set model characters - model_characters = phonemes if config.use_phonemes else symbols - num_chars = len(model_characters) - - # load data instances - meta_data_train, meta_data_eval = load_meta_data(config.datasets, eval_split=True) - - # parse speakers - num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH) - - # setup model - model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim) - optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) - criterion = AlignTTSLoss(config) - - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - # TODO: fix optimizer init, model.cuda() needs to be called before - # optimizer restore - optimizer.load_state_dict(checkpoint["optimizer"]) - if config.reinit_layers: - raise RuntimeError - model.load_state_dict(checkpoint["model"]) - except: # pylint: disable=bare-except - print(" > Partial model initialization.") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - for group in optimizer.param_groups: - group["initial_lr"] = config.lr - print(" > Model restored from step %d" % checkpoint["step"], 
flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - if use_cuda: - model.cuda() - criterion.cuda() - - # DISTRUBUTED - if num_gpus > 1: - model = DDP_th(model, device_ids=[args.rank]) - - if config.noam_schedule: - scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1) - else: - scheduler = None - - num_params = count_parameters(model) - print("\n > Model has {} parameters".format(num_params), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = config.keep_all_best - keep_after = config.keep_after # void if keep_all_best False - - # define dataloaders - train_loader = setup_loader(ap, 1, is_val=False, verbose=True) - eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) - - global_step = args.restore_step - - def set_phase(): - """Set AlignTTS training phase""" - if isinstance(config.phase_start_steps, list): - vals = [i < global_step for i in config.phase_start_steps] - if not True in vals: - phase = 0 - else: - phase = ( - len(config.phase_start_steps) - - [i < global_step for i in config.phase_start_steps][::-1].index(True) - - 1 - ) - else: - phase = None - return phase - - for epoch in range(0, config.epochs): - cur_phase = set_phase() - print(f"\n > Current AlignTTS phase: {cur_phase}") - c_logger.print_epoch_start(epoch, config.epochs) - train_avg_loss_dict, global_step = train( - train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase - ) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict["avg_loss"] - if config.run_eval: - target_loss = eval_avg_loss_dict["avg_loss"] - best_loss = save_best_model( - target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after, - ) - - -if __name__ == "__main__": - args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py deleted file mode 100644 index a138abeb..00000000 --- a/TTS/bin/train_glow_tts.py +++ /dev/null @@ -1,598 +0,0 @@ -#!/usr/bin/env python3 -"""Train Glow TTS model.""" - -import os -import sys -import time -import traceback -from random import randrange - -import torch - -# DISTRIBUTED -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from TTS.tts.datasets import load_meta_data -from TTS.tts.datasets.TTSDataset import TTSDataset -from TTS.tts.layers.losses import GlowTTSLoss -from TTS.tts.models import setup_model -from TTS.tts.utils.io import save_best_model, save_checkpoint -from TTS.tts.utils.measures import alignment_diagonal_score -from 
TTS.tts.utils.speakers import parse_speakers -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.radam import RAdam -from TTS.utils.training import NoamLR, setup_torch_training_env - -use_cuda, num_gpus = setup_torch_training_env(True, False) - - -def setup_loader(ap, r, is_val=False, verbose=False): - if is_val and not config.run_eval: - loader = None - else: - dataset = TTSDataset( - r, - config.text_cleaner, - compute_linear_spec=False, - meta_data=meta_data_eval if is_val else meta_data_train, - ap=ap, - tp=config.characters, - add_blank=config["add_blank"], - batch_group_size=0 if is_val else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, - phoneme_cache_path=config.phoneme_cache_path, - use_phonemes=config.use_phonemes, - phoneme_language=config.phoneme_language, - enable_eos_bos=config.enable_eos_bos_chars, - use_noise_augment=not is_val, - verbose=verbose, - speaker_mapping=speaker_mapping - if config.use_speaker_embedding and config.use_external_speaker_embedding_file - else None, - ) - - if config.use_phonemes and config.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. - dataset.compute_input_seq(config.num_loader_workers) - dataset.sort_items() - - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=config.eval_batch_size if is_val else config.batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - sampler=sampler, - num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers, - pin_memory=False, - ) - return loader - - -def format_data(data): - # setup input data - text_input = data[0] - text_lengths = data[1] - speaker_names = data[2] - mel_input = data[4].permute(0, 2, 1) # B x D x T - mel_lengths = data[5] - item_idx = data[7] - attn_mask = data[9] - avg_text_length = torch.mean(text_lengths.float()) - avg_spec_length = torch.mean(mel_lengths.float()) - - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - # return precomputed embedding vector - speaker_c = data[8] - else: - # return speaker_id to be used by an embedding layer - speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] - speaker_c = torch.LongTensor(speaker_c) - else: - speaker_c = None - - # dispatch data to GPU - if use_cuda: - text_input = text_input.cuda(non_blocking=True) - text_lengths = text_lengths.cuda(non_blocking=True) - mel_input = mel_input.cuda(non_blocking=True) - mel_lengths = mel_lengths.cuda(non_blocking=True) - if speaker_c is not None: - speaker_c = speaker_c.cuda(non_blocking=True) - if attn_mask is not None: - attn_mask = attn_mask.cuda(non_blocking=True) - return ( - text_input, - text_lengths, - mel_input, - mel_lengths, - speaker_c, - avg_text_length, - avg_spec_length, - attn_mask, - item_idx, - ) - - -def data_depended_init(data_loader, model): - """Data depended initialization for activation normalization.""" - if hasattr(model, "module"): - for f in model.module.decoder.flows: - if getattr(f, "set_ddi", 
False): - f.set_ddi(True) - else: - for f in model.decoder.flows: - if getattr(f, "set_ddi", False): - f.set_ddi(True) - - model.train() - print(" > Data depended initialization ... ") - num_iter = 0 - with torch.no_grad(): - for _, data in enumerate(data_loader): - - # format data - text_input, text_lengths, mel_input, mel_lengths, spekaer_embed, _, _, attn_mask, _ = format_data(data) - - # forward pass model - _ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed) - if num_iter == config.data_dep_init_steps: - break - num_iter += 1 - - if hasattr(model, "module"): - for f in model.module.decoder.flows: - if getattr(f, "set_ddi", False): - f.set_ddi(False) - else: - for f in model.decoder.flows: - if getattr(f, "set_ddi", False): - f.set_ddi(False) - return model - - -def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): - - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / config.batch_size) - end_time = time.time() - c_logger.print_train_start() - scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - ( - text_input, - text_lengths, - mel_input, - mel_lengths, - speaker_c, - avg_text_length, - avg_spec_length, - attn_mask, - _, - ) = format_data(data) - - loader_time = time.time() - end_time - - global_step += 1 - optimizer.zero_grad() - - # forward pass model - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c - ) - - # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) - - # backward pass with loss scaling - if config.mixed_precision: - scaler.scale(loss_dict["loss"]).backward() - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - loss_dict["loss"].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - optimizer.step() - - # setup lr - if config.noam_schedule: - scheduler.step() - - # current_lr - current_lr = optimizer.param_groups[0]["lr"] - - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict["align_error"] = align_error - - step_time = time.time() - start_time - epoch_time += step_time - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # 
print training progress - if global_step % config.print_step == 0: - log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision - "avg_text_length": [avg_text_length, 1], - "step_time": [step_time, 4], - "loader_time": [loader_time, 2], - "current_lr": current_lr, - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Plot Training Iter Stats - # reduce TB load - if global_step % config.tb_plot_step == 0: - iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - if global_step % config.save_step == 0: - if config.checkpoint: - # save model - save_checkpoint( - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - model_loss=loss_dict["loss"], - ) - - # wait all kernels to be completed - torch.cuda.synchronize() - - # Diagnostic visualizations - # direct pass on model for spec predictions - target_speaker = None if speaker_c is None else speaker_c[:1] - - if hasattr(model, "module"): - spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) - else: - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) - - spec_pred = spec_pred.permute(0, 2, 1) - gt_spec = mel_input.permute(0, 2, 1) - const_spec = spec_pred[0].data.cpu().numpy() - gt_spec = gt_spec[0].data.cpu().numpy() - align_img = alignments[0].data.cpu().numpy() - - figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img), - } - - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"]) - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Epoch Stats - if args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - if config.tb_model_param_stats: - tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(data_loader, model, criterion, ap, global_step, epoch): - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - c_logger.print_eval_start() - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data) - - # forward pass model - z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c - ) - - # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) - - # step time - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments) - loss_dict["align_error"] = align_error - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = 
reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - keep_avg.update_values(update_train_values) - - if config.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Diagnostic visualizations - # direct pass on model for spec predictions - target_speaker = None if speaker_c is None else speaker_c[:1] - if hasattr(model, "module"): - spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) - else: - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) - spec_pred = spec_pred.permute(0, 2, 1) - gt_spec = mel_input.permute(0, 2, 1) - - const_spec = spec_pred[0].data.cpu().numpy() - gt_spec = gt_spec[0].data.cpu().numpy() - align_img = alignments[0].data.cpu().numpy() - - eval_figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img), - } - - # Sample audio - eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"]) - - # Plot Validation Stats - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - tb_logger.tb_eval_figures(global_step, eval_figures) - - if args.rank == 0 and epoch >= config.test_delay_epochs: - if config.test_sentences_file: - with open(config.test_sentences_file, "r") as f: - test_sentences = [s.strip() for s in f.readlines()] - else: - test_sentences = [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963.", - ] - - # test sentences - test_audios = {} - test_figures = {} - print(" | > Synthesizing test sentences") - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ - "embedding" - ] - speaker_id = None - else: - speaker_id = 0 - speaker_embedding = None - else: - speaker_id = None - speaker_embedding = None - - style_wav = config.style_wav_for_test - for idx, test_sentence in enumerate(test_sentences): - try: - wav, alignment, _, postnet_output, _, _ = synthesis( - model, - test_sentence, - config, - use_cuda, - ap, - speaker_id=speaker_id, - speaker_embedding=speaker_embedding, - style_wav=style_wav, - truncated=False, - enable_eos_bos_chars=config.enable_eos_bos_chars, # pylint: disable=unused-argument - use_griffin_lim=True, - do_trim_silence=False, - ) - - file_path = os.path.join(AUDIO_PATH, str(global_step)) - os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) - ap.save_wav(wav, file_path) - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) - except: # pylint: disable=bare-except - print(" !! 
Error creating Test Sentence -", idx) - traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"]) - tb_logger.tb_test_figures(global_step, test_figures) - return keep_avg.avg_values - - -def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping - # Audio processor - ap = AudioProcessor(**config.audio.to_dict()) - if config.has("characters") and config.characters: - symbols, phonemes = make_symbols(**config.characters.to_dict()) - - # DISTRUBUTED - if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"]) - - # set model characters - model_characters = phonemes if config.use_phonemes else symbols - num_chars = len(model_characters) - - # load data instances - meta_data_train, meta_data_eval = load_meta_data(config.datasets) - - # parse speakers - num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH) - - # setup model - model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim) - optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) - criterion = GlowTTSLoss() - - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - # TODO: fix optimizer init, model.cuda() needs to be called before - # optimizer restore - optimizer.load_state_dict(checkpoint["optimizer"]) - model.load_state_dict(checkpoint["model"]) - except: # pylint: disable=bare-except - print(" > Partial model initialization.") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - for group in optimizer.param_groups: - group["initial_lr"] = config.lr - print(f" > Model restored from step {checkpoint['step']:d}", flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - if use_cuda: - model.cuda() - criterion.cuda() - - # DISTRUBUTED - if num_gpus > 1: - model = DDP_th(model, device_ids=[args.rank]) - - if config.noam_schedule: - scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1) - else: - scheduler = None - - num_params = count_parameters(model) - print("\n > Model has {} parameters".format(num_params), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = config.keep_all_best - keep_after = config.keep_after # void if keep_all_best False - - # define dataloaders - train_loader = setup_loader(ap, 1, is_val=False, verbose=True) - eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) - - global_step = args.restore_step - model = data_depended_init(train_loader, model) - for epoch in range(0, config.epochs): - c_logger.print_epoch_start(epoch, config.epochs) - train_avg_loss_dict, global_step = train( - train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch - ) - eval_avg_loss_dict = evaluate(eval_loader, 
model, criterion, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict["avg_loss"] - if config.run_eval: - target_loss = eval_avg_loss_dict["avg_loss"] - best_loss = save_best_model( - target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - config.r, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after, - ) - - -if __name__ == "__main__": - args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py deleted file mode 100644 index 4dc3f5f0..00000000 --- a/TTS/bin/train_speedy_speech.py +++ /dev/null @@ -1,578 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import os -import sys -import time -import traceback -from random import randrange - -import numpy as np -import torch - -# DISTRIBUTED -from torch.nn.parallel import DistributedDataParallel as DDP_th -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from TTS.tts.datasets import load_meta_data -from TTS.tts.datasets.TTSDataset import TTSDataset -from TTS.tts.layers.losses import SpeedySpeechLoss -from TTS.tts.models import setup_model -from TTS.tts.utils.io import save_best_model, save_checkpoint -from TTS.tts.utils.measures import alignment_diagonal_score -from TTS.tts.utils.speakers import parse_speakers -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.radam import RAdam -from TTS.utils.training import NoamLR, setup_torch_training_env - -use_cuda, num_gpus = setup_torch_training_env(True, False) - - -def setup_loader(ap, r, is_val=False, verbose=False): - if is_val and not config.run_eval: - loader = None - else: - dataset = TTSDataset( - r, - config.text_cleaner, - compute_linear_spec=False, - meta_data=meta_data_eval if is_val else meta_data_train, - ap=ap, - tp=config.characters, - add_blank=config["add_blank"], - batch_group_size=0 if is_val else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, - phoneme_cache_path=config.phoneme_cache_path, - use_phonemes=config.use_phonemes, - phoneme_language=config.phoneme_language, - enable_eos_bos=config.enable_eos_bos_chars, - use_noise_augment=not is_val, - verbose=verbose, - speaker_mapping=speaker_mapping - if config.use_speaker_embedding and config.use_external_speaker_embedding_file - else None, - ) - - if config.use_phonemes and config.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. 
- dataset.compute_input_seq(config.num_loader_workers) - dataset.sort_items() - - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=config.eval_batch_size if is_val else config.batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - sampler=sampler, - num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers, - pin_memory=False, - ) - return loader - - -def format_data(data): - # setup input data - text_input = data[0] - text_lengths = data[1] - speaker_names = data[2] - mel_input = data[4].permute(0, 2, 1) # B x D x T - mel_lengths = data[5] - item_idx = data[7] - attn_mask = data[9] - avg_text_length = torch.mean(text_lengths.float()) - avg_spec_length = torch.mean(mel_lengths.float()) - - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - # return precomputed embedding vector - speaker_c = data[8] - else: - # return speaker_id to be used by an embedding layer - speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] - speaker_c = torch.LongTensor(speaker_c) - else: - speaker_c = None - # compute durations from attention mask - durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) - for idx, am in enumerate(attn_mask): - # compute raw durations - c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] - # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) - c_idxs, counts = torch.unique(c_idxs, return_counts=True) - dur = torch.ones([text_lengths[idx]]).to(counts.dtype) - dur[c_idxs] = counts - # smooth the durations and set any 0 duration to 1 - # by cutting off from the largest duration indeces. - extra_frames = dur.sum() - mel_lengths[idx] - largest_idxs = torch.argsort(-dur)[:extra_frames] - dur[largest_idxs] -= 1 - assert ( - dur.sum() == mel_lengths[idx] - ), f" [!] 
total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" - durations[idx, : text_lengths[idx]] = dur - # dispatch data to GPU - if use_cuda: - text_input = text_input.cuda(non_blocking=True) - text_lengths = text_lengths.cuda(non_blocking=True) - mel_input = mel_input.cuda(non_blocking=True) - mel_lengths = mel_lengths.cuda(non_blocking=True) - if speaker_c is not None: - speaker_c = speaker_c.cuda(non_blocking=True) - attn_mask = attn_mask.cuda(non_blocking=True) - durations = durations.cuda(non_blocking=True) - return ( - text_input, - text_lengths, - mel_input, - mel_lengths, - speaker_c, - avg_text_length, - avg_spec_length, - attn_mask, - durations, - item_idx, - ) - - -def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): - - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / config.batch_size) - end_time = time.time() - c_logger.print_train_start() - scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - ( - text_input, - text_lengths, - mel_targets, - mel_lengths, - speaker_c, - avg_text_length, - avg_spec_length, - _, - dur_target, - _, - ) = format_data(data) - - loader_time = time.time() - end_time - - global_step += 1 - optimizer.zero_grad() - - # forward pass model - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c - ) - - # compute loss - loss_dict = criterion( - decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths - ) - - # backward pass with loss scaling - if config.mixed_precision: - scaler.scale(loss_dict["loss"]).backward() - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - loss_dict["loss"].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) - optimizer.step() - - # setup lr - if config.noam_schedule: - scheduler.step() - - # current_lr - current_lr = optimizer.param_groups[0]["lr"] - - # compute alignment error (the lower the better ) - align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict["align_error"] = align_error - - step_time = time.time() - start_time - epoch_time += step_time - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) - loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - update_train_values["avg_loader_time"] = loader_time - update_train_values["avg_step_time"] = step_time - keep_avg.update_values(update_train_values) - - # print training progress - if global_step % 
config.print_step == 0: - log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision - "avg_text_length": [avg_text_length, 1], - "step_time": [step_time, 4], - "loader_time": [loader_time, 2], - "current_lr": current_lr, - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Plot Training Iter Stats - # reduce TB load - if global_step % config.tb_plot_step == 0: - iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - if global_step % config.save_step == 0: - if config.checkpoint: - # save model - save_checkpoint( - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - model_loss=loss_dict["loss"], - ) - - # wait all kernels to be completed - torch.cuda.synchronize() - - # Diagnostic visualizations - idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu().numpy().T - gt_spec = mel_targets[idx].data.cpu().numpy().T - align_img = alignments[idx].data.cpu() - - figures = { - "prediction": plot_spectrogram(pred_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img), - } - - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - train_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"]) - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Epoch Stats - if args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - if config.tb_model_param_stats: - tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(data_loader, model, criterion, ap, global_step, epoch): - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - c_logger.print_eval_start() - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data) - - # forward pass model - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c - ) - - # compute loss - loss_dict = criterion( - decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths - ) - - # step time - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict["align_error"] = align_error - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) - loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) - loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) - loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = 
loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - keep_avg.update_values(update_train_values) - - if config.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Diagnostic visualizations - idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu().numpy().T - gt_spec = mel_targets[idx].data.cpu().numpy().T - align_img = alignments[idx].data.cpu() - - eval_figures = { - "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False), - } - - # Sample audio - eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"]) - - # Plot Validation Stats - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - tb_logger.tb_eval_figures(global_step, eval_figures) - - if args.rank == 0 and epoch >= config.test_delay_epochs: - if config.test_sentences_file: - with open(config.test_sentences_file, "r") as f: - test_sentences = [s.strip() for s in f.readlines()] - else: - test_sentences = [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963.", - ] - - # test sentences - test_audios = {} - test_figures = {} - print(" | > Synthesizing test sentences") - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ - "embedding" - ] - speaker_id = None - else: - speaker_id = 0 - speaker_embedding = None - else: - speaker_id = None - speaker_embedding = None - - for idx, test_sentence in enumerate(test_sentences): - try: - wav, alignment, _, postnet_output, _, _ = synthesis( - model, - test_sentence, - config, - use_cuda, - ap, - speaker_id=speaker_id, - speaker_embedding=speaker_embedding, - style_wav=None, - truncated=False, - enable_eos_bos_chars=config.enable_eos_bos_chars, # pylint: disable=unused-argument - use_griffin_lim=True, - do_trim_silence=False, - ) - - file_path = os.path.join(AUDIO_PATH, str(global_step)) - os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) - ap.save_wav(wav, file_path) - test_audios["{}-audio".format(idx)] = wav - test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) - test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) - except: # pylint: disable=bare-except - print(" !! Error creating Test Sentence -", idx) - traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"]) - tb_logger.tb_test_figures(global_step, test_figures) - return keep_avg.avg_values - - -# FIXME: move args definition/parsing inside of main? 
-def main(args): # pylint: disable=redefined-outer-name - # pylint: disable=global-variable-undefined - global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping - # Audio processor - ap = AudioProcessor(**config.audio.to_dict()) - if config.characters is not None: - symbols, phonemes = make_symbols(**config.characters.to_dict()) - - # DISTRUBUTED - if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"]) - - # set model characters - model_characters = phonemes if config.use_phonemes else symbols - num_chars = len(model_characters) - - # load data instances - meta_data_train, meta_data_eval = load_meta_data(config.datasets, eval_split=True) - - # set the portion of the data used for training if set in config.json - if config.has("train_portion"): - meta_data_train = meta_data_train[: int(len(meta_data_train) * config.train_portion)] - if config.has("eval_portion"): - meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * config.eval_portion)] - - # parse speakers - num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH) - - # setup model - model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim) - optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) - criterion = SpeedySpeechLoss(config) - - if args.restore_path: - print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location="cpu") - try: - # TODO: fix optimizer init, model.cuda() needs to be called before - # optimizer restore - optimizer.load_state_dict(checkpoint["optimizer"]) - if config.reinit_layers: - raise RuntimeError - model.load_state_dict(checkpoint["model"]) - except: # pylint: disable=bare-except - print(" > Partial model initialization.") - model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint["model"], config) - model.load_state_dict(model_dict) - del model_dict - - for group in optimizer.param_groups: - group["initial_lr"] = config.lr - print(" > Model restored from step %d" % checkpoint["step"], flush=True) - args.restore_step = checkpoint["step"] - else: - args.restore_step = 0 - - if use_cuda: - model.cuda() - criterion.cuda() - - # DISTRUBUTED - if num_gpus > 1: - model = DDP_th(model, device_ids=[args.rank]) - - if config.noam_schedule: - scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1) - else: - scheduler = None - - num_params = count_parameters(model) - print("\n > Model has {} parameters".format(num_params), flush=True) - - if args.restore_step == 0 or not args.best_path: - best_loss = float("inf") - print(" > Starting with inf best loss.") - else: - print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] - print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = config.keep_all_best - keep_after = config.keep_after # void if keep_all_best False - - # define dataloaders - train_loader = setup_loader(ap, 1, is_val=False, verbose=True) - eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) - - global_step = args.restore_step - for epoch in range(0, config.epochs): - c_logger.print_epoch_start(epoch, config.epochs) - train_avg_loss_dict, global_step = train( - train_loader, model, criterion, 
optimizer, scheduler, ap, global_step, epoch - ) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict["avg_loss"] - if config.run_eval: - target_loss = eval_avg_loss_dict["avg_loss"] - best_loss = save_best_model( - target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - config.r, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after, - ) - - -if __name__ == "__main__": - args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py deleted file mode 100755 index 69ffbb6c..00000000 --- a/TTS/bin/train_tacotron.py +++ /dev/null @@ -1,749 +0,0 @@ -#!/usr/bin/env python3 -"""Trains Tacotron based TTS models.""" - -import os -import sys -import time -import traceback -from random import randrange - -import numpy as np -import torch -from torch.utils.data import DataLoader - -from TTS.tts.datasets import load_meta_data -from TTS.tts.datasets.TTSDataset import TTSDataset -from TTS.tts.layers.losses import TacotronLoss -from TTS.tts.models import setup_model -from TTS.tts.utils.io import save_best_model, save_checkpoint -from TTS.tts.utils.measures import alignment_diagonal_score -from TTS.tts.utils.speakers import parse_speakers -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.arguments import init_training -from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict -from TTS.utils.radam import RAdam -from TTS.utils.training import ( - NoamLR, - adam_weight_decay, - check_update, - gradual_training_scheduler, - set_weight_decay, - setup_torch_training_env, -) - -use_cuda, num_gpus = setup_torch_training_env(True, False) - - -def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): - if is_val and not config.run_eval: - loader = None - else: - if dataset is None: - dataset = TTSDataset( - r, - config.text_cleaner, - compute_linear_spec=config.model.lower() == "tacotron", - meta_data=meta_data_eval if is_val else meta_data_train, - ap=ap, - tp=config.characters, - add_blank=config["add_blank"], - batch_group_size=0 if is_val else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, - phoneme_cache_path=config.phoneme_cache_path, - use_phonemes=config.use_phonemes, - phoneme_language=config.phoneme_language, - enable_eos_bos=config.enable_eos_bos_chars, - verbose=verbose, - speaker_mapping=( - speaker_mapping - if (config.use_speaker_embedding and config.use_external_speaker_embedding_file) - else None - ), - ) - - if config.use_phonemes and config.compute_input_seq_cache: - # precompute phonemes to have a better estimate of sequence lengths. 
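-                # (the cached sequences presumably give the following sort_items()
-                # call accurate lengths to sort and filter items by, rather than
-                # raw character counts)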
- dataset.compute_input_seq(config.num_loader_workers) - dataset.sort_items() - - sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader( - dataset, - batch_size=config.eval_batch_size if is_val else config.batch_size, - shuffle=False, - collate_fn=dataset.collate_fn, - drop_last=False, - sampler=sampler, - num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers, - pin_memory=False, - ) - return loader - - -def format_data(data): - # setup input data - text_input = data[0] - text_lengths = data[1] - speaker_names = data[2] - linear_input = data[3] if config.model.lower() in ["tacotron"] else None - mel_input = data[4] - mel_lengths = data[5] - stop_targets = data[6] - max_text_length = torch.max(text_lengths.float()) - max_spec_length = torch.max(mel_lengths.float()) - - if config.use_speaker_embedding: - if config.use_external_speaker_embedding_file: - speaker_embeddings = data[8] - speaker_ids = None - else: - speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names] - speaker_ids = torch.LongTensor(speaker_ids) - speaker_embeddings = None - else: - speaker_embeddings = None - speaker_ids = None - - # set stop targets view, we predict a single stop token per iteration. - stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // config.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) - - # dispatch data to GPU - if use_cuda: - text_input = text_input.cuda(non_blocking=True) - text_lengths = text_lengths.cuda(non_blocking=True) - mel_input = mel_input.cuda(non_blocking=True) - mel_lengths = mel_lengths.cuda(non_blocking=True) - linear_input = linear_input.cuda(non_blocking=True) if config.model.lower() in ["tacotron"] else None - stop_targets = stop_targets.cuda(non_blocking=True) - if speaker_ids is not None: - speaker_ids = speaker_ids.cuda(non_blocking=True) - if speaker_embeddings is not None: - speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) - - return ( - text_input, - text_lengths, - mel_input, - mel_lengths, - linear_input, - stop_targets, - speaker_ids, - speaker_embeddings, - max_text_length, - max_spec_length, - ) - - -def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st): - model.train() - epoch_time = 0 - keep_avg = KeepAverage() - if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus)) - else: - batch_n_iter = int(len(data_loader.dataset) / config.batch_size) - end_time = time.time() - c_logger.print_train_start() - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - ( - text_input, - text_lengths, - mel_input, - mel_lengths, - linear_input, - stop_targets, - speaker_ids, - speaker_embeddings, - max_text_length, - max_spec_length, - ) = format_data(data) - loader_time = time.time() - end_time - - global_step += 1 - - # setup lr - if config.noam_schedule: - scheduler.step() - - optimizer.zero_grad() - if optimizer_st: - optimizer_st.zero_grad() - - with torch.cuda.amp.autocast(enabled=config.mixed_precision): - # forward pass model - if config.bidirectional_decoder or config.double_decoder_consistency: - ( - decoder_output, - postnet_output, - alignments, - stop_tokens, - decoder_backward_output, - alignments_backward, - ) = model( - text_input, - text_lengths, - mel_input, - mel_lengths, - speaker_ids=speaker_ids, - speaker_embeddings=speaker_embeddings, - ) - else: - decoder_output, 
postnet_output, alignments, stop_tokens = model(
-                    text_input,
-                    text_lengths,
-                    mel_input,
-                    mel_lengths,
-                    speaker_ids=speaker_ids,
-                    speaker_embeddings=speaker_embeddings,
-                )
-                decoder_backward_output = None
-                alignments_backward = None
-
-            # set the alignment lengths wrt the reduction factor for guided attention
-            if mel_lengths.max() % model.decoder.r != 0:
-                alignment_lengths = (
-                    mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
-                ) // model.decoder.r
-            else:
-                alignment_lengths = mel_lengths // model.decoder.r
-
-            # compute loss
-            loss_dict = criterion(
-                postnet_output,
-                decoder_output,
-                mel_input,
-                linear_input,
-                stop_tokens,
-                stop_targets,
-                mel_lengths,
-                decoder_backward_output,
-                alignments,
-                alignment_lengths,
-                alignments_backward,
-                text_lengths,
-            )
-
-        # check nan loss
-        if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f"Detected NaN loss at step {global_step}.")
-
-        # optimizer step
-        if config.mixed_precision:
-            # model optimizer step in mixed precision mode
-            scaler.scale(loss_dict["loss"]).backward()
-            scaler.unscale_(optimizer)
-            optimizer, current_lr = adam_weight_decay(optimizer)
-            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
-            scaler.step(optimizer)
-            scaler.update()
-
-            # stopnet optimizer step
-            if config.separate_stopnet:
-                # keep the stopnet's own scaler/optimizer pair matched throughout
-                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
-                scaler_st.unscale_(optimizer_st)
-                optimizer_st, _ = adam_weight_decay(optimizer_st)
-                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
-                scaler_st.step(optimizer_st)
-                scaler_st.update()
-            else:
-                grad_norm_st = 0
-        else:
-            # main model optimizer step
-            loss_dict["loss"].backward()
-            optimizer, current_lr = adam_weight_decay(optimizer)
-            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
-            optimizer.step()
-
-            # stopnet optimizer step
-            if config.separate_stopnet:
-                loss_dict["stopnet_loss"].backward()
-                optimizer_st, _ = adam_weight_decay(optimizer_st)
-                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
-                optimizer_st.step()
-            else:
-                grad_norm_st = 0
-
-        # compute alignment error (the lower the better)
-        align_error = 1 - alignment_diagonal_score(alignments)
-        loss_dict["align_error"] = align_error
-
-        step_time = time.time() - start_time
-        epoch_time += step_time
-
-        # aggregate losses from processes
-        if num_gpus > 1:
-            loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
-            loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
-            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
-            loss_dict["stopnet_loss"] = (
-                reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if config.stopnet else loss_dict["stopnet_loss"]
-            )
-
-        # detach loss values
-        loss_dict_new = dict()
-        for key, value in loss_dict.items():
-            if isinstance(value, (int, float)):
-                loss_dict_new[key] = value
-            else:
-                loss_dict_new[key] = value.item()
-        loss_dict = loss_dict_new
-
-        # update avg stats
-        update_train_values = dict()
-        for key, value in loss_dict.items():
-            update_train_values["avg_" + key] = value
-        update_train_values["avg_loader_time"] = loader_time
-        update_train_values["avg_step_time"] = step_time
-        keep_avg.update_values(update_train_values)
-
-        # print training progress
-        if global_step % config.print_step == 0:
-            log_dict = {
-                "max_spec_length": [max_spec_length, 1],  # value, precision
-                "max_text_length": [max_text_length, 1],
-                "step_time": [step_time, 4],
-                "loader_time": [loader_time, 2],
-                "current_lr": 
current_lr, - } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Plot Training Iter Stats - # reduce TB load - if global_step % config.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "grad_norm_st": grad_norm_st, - "step_time": step_time, - } - iter_stats.update(loss_dict) - tb_logger.tb_train_step_stats(global_step, iter_stats) - - if global_step % config.save_step == 0: - if config.checkpoint: - # save model - save_checkpoint( - model, - optimizer, - global_step, - epoch, - model.decoder.r, - OUT_PATH, - optimizer_st=optimizer_st, - model_loss=loss_dict["postnet_loss"], - characters=model_characters, - scaler=scaler.state_dict() if config.mixed_precision else None, - ) - - # Diagnostic visualizations - const_spec = postnet_output[0].data.cpu().numpy() - gt_spec = ( - linear_input[0].data.cpu().numpy() - if config.model in ["Tacotron", "TacotronGST"] - else mel_input[0].data.cpu().numpy() - ) - align_img = alignments[0].data.cpu().numpy() - - figures = { - "prediction": plot_spectrogram(const_spec, ap, output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False), - } - - if config.bidirectional_decoder or config.double_decoder_consistency: - figures["alignment_backward"] = plot_alignment( - alignments_backward[0].data.cpu().numpy(), output_fig=False - ) - - tb_logger.tb_train_figures(global_step, figures) - - # Sample audio - if config.model in ["Tacotron", "TacotronGST"]: - train_audio = ap.inv_spectrogram(const_spec.T) - else: - train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"]) - end_time = time.time() - - # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) - - # Plot Epoch Stats - if args.rank == 0: - epoch_stats = {"epoch_time": epoch_time} - epoch_stats.update(keep_avg.avg_values) - tb_logger.tb_train_epoch_stats(global_step, epoch_stats) - if config.tb_model_param_stats: - tb_logger.tb_model_weights(model, global_step) - return keep_avg.avg_values, global_step - - -@torch.no_grad() -def evaluate(data_loader, model, criterion, ap, global_step, epoch): - model.eval() - epoch_time = 0 - keep_avg = KeepAverage() - c_logger.print_eval_start() - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - start_time = time.time() - - # format data - ( - text_input, - text_lengths, - mel_input, - mel_lengths, - linear_input, - stop_targets, - speaker_ids, - speaker_embeddings, - _, - _, - ) = format_data(data) - assert mel_input.shape[1] % model.decoder.r == 0 - - # forward pass model - if config.bidirectional_decoder or config.double_decoder_consistency: - ( - decoder_output, - postnet_output, - alignments, - stop_tokens, - decoder_backward_output, - alignments_backward, - ) = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings - ) - else: - decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings - ) - decoder_backward_output = None - alignments_backward = None - - # set the alignment lengths wrt reduction factor for guided attention - if mel_lengths.max() % model.decoder.r != 0: - alignment_lengths = ( - mel_lengths + (model.decoder.r - (mel_lengths.max() % 
model.decoder.r)) - ) // model.decoder.r - else: - alignment_lengths = mel_lengths // model.decoder.r - - # compute loss - loss_dict = criterion( - postnet_output, - decoder_output, - mel_input, - linear_input, - stop_tokens, - stop_targets, - mel_lengths, - decoder_backward_output, - alignments, - alignment_lengths, - alignments_backward, - text_lengths, - ) - - # step time - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments) - loss_dict["align_error"] = align_error - - # aggregate losses from processes - if num_gpus > 1: - loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus) - loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus) - if config.stopnet: - loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) - - # detach loss values - loss_dict_new = dict() - for key, value in loss_dict.items(): - if isinstance(value, (int, float)): - loss_dict_new[key] = value - else: - loss_dict_new[key] = value.item() - loss_dict = loss_dict_new - - # update avg stats - update_train_values = dict() - for key, value in loss_dict.items(): - update_train_values["avg_" + key] = value - keep_avg.update_values(update_train_values) - - if config.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - - if args.rank == 0: - # Diagnostic visualizations - idx = np.random.randint(mel_input.shape[0]) - const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = ( - linear_input[idx].data.cpu().numpy() - if config.model in ["Tacotron", "TacotronGST"] - else mel_input[idx].data.cpu().numpy() - ) - align_img = alignments[idx].data.cpu().numpy() - - eval_figures = { - "prediction": plot_spectrogram(const_spec, ap, output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False), - } - - # Sample audio - if config.model.lower() in ["tacotron"]: - eval_audio = ap.inv_spectrogram(const_spec.T) - else: - eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"]) - - # Plot Validation Stats - - if config.bidirectional_decoder or config.double_decoder_consistency: - align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False) - tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) - tb_logger.tb_eval_figures(global_step, eval_figures) - - if args.rank == 0 and epoch > config.test_delay_epochs: - if config.test_sentences_file: - with open(config.test_sentences_file, "r") as f: - test_sentences = [s.strip() for s in f.readlines()] - else: - test_sentences = [ - "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", - "Be a voice, not an echo.", - "I'm sorry Dave. I'm afraid I can't do that.", - "This cake is great. 
It's so delicious and moist.",
-                "Prior to November 22, 1963.",
-            ]
-
-        # test sentences
-        test_audios = {}
-        test_figures = {}
-        print(" | > Synthesizing test sentences")
-        speaker_id = 0 if config.use_speaker_embedding else None
-        speaker_embedding = (
-            speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping))]]["embedding"]
-            if config.use_external_speaker_embedding_file and config.use_speaker_embedding
-            else None
-        )
-        style_wav = config.gst_style_input
-        if style_wav is None and config.gst is not None:
-            # initialize GST with a zero style dict
-            style_wav = {}
-            print("WARNING: No GST style wav was provided, so a zero tensor is used instead!")
-            for i in range(config.gst["gst_num_style_tokens"]):
-                style_wav[str(i)] = 0
-        for idx, test_sentence in enumerate(test_sentences):
-            try:
-                wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
-                    model,
-                    test_sentence,
-                    config,
-                    use_cuda,
-                    ap,
-                    speaker_id=speaker_id,
-                    speaker_embedding=speaker_embedding,
-                    style_wav=style_wav,
-                    truncated=False,
-                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
-                    use_griffin_lim=True,
-                    do_trim_silence=False,
-                )
-
-                file_path = os.path.join(AUDIO_PATH, str(global_step))
-                os.makedirs(file_path, exist_ok=True)
-                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
-                ap.save_wav(wav, file_path)
-                test_audios["{}-audio".format(idx)] = wav
-                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False)
-                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
-            except:  # pylint: disable=bare-except
-                print(" !! Error creating Test Sentence -", idx)
-                traceback.print_exc()
-        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
-        tb_logger.tb_test_figures(global_step, test_figures)
-    return keep_avg.avg_values
-
-
-def main(args):  # pylint: disable=redefined-outer-name
-    # pylint: disable=global-variable-undefined
-    global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
-    # Audio processor
-    ap = AudioProcessor(**config.audio.to_dict())
-
-    # setup custom characters if set in config file.
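-    # (make_symbols() rebuilds the global symbol/phoneme tables from
-    # config.characters, overriding the defaults imported from
-    # TTS.tts.utils.text.symbols)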
-    if config.characters is not None:
-        symbols, phonemes = make_symbols(**config.characters.to_dict())
-
-    # DISTRIBUTED
-    if num_gpus > 1:
-        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])
-    num_chars = len(phonemes) if config.use_phonemes else len(symbols)
-    model_characters = phonemes if config.use_phonemes else symbols
-
-    # load data instances
-    meta_data_train, meta_data_eval = load_meta_data(config.datasets)
-
-    # set the portion of the data used for training
-    if config.has("train_portion"):
-        meta_data_train = meta_data_train[: int(len(meta_data_train) * config.train_portion)]
-    if config.has("eval_portion"):
-        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * config.eval_portion)]
-
-    # parse speakers
-    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)
-
-    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim)
-
-    # scalers for mixed precision training
-    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
-    scaler_st = torch.cuda.amp.GradScaler() if config.mixed_precision and config.separate_stopnet else None
-
-    params = set_weight_decay(model, config.wd)
-    optimizer = RAdam(params, lr=config.lr, weight_decay=0)
-    if config.stopnet and config.separate_stopnet:
-        optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=config.lr, weight_decay=0)
-    else:
-        optimizer_st = None
-
-    # setup criterion
-    criterion = TacotronLoss(config, stopnet_pos_weight=config.stopnet_pos_weight, ga_sigma=0.4)
-    if args.restore_path:
-        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
-        checkpoint = torch.load(args.restore_path, map_location="cpu")
-        try:
-            print(" > Restoring Model...")
-            model.load_state_dict(checkpoint["model"])
-            # optimizer restore
-            print(" > Restoring Optimizer...")
-            optimizer.load_state_dict(checkpoint["optimizer"])
-            if "scaler" in checkpoint and config.mixed_precision:
-                print(" > Restoring AMP Scaler...")
-                scaler.load_state_dict(checkpoint["scaler"])
-        except (KeyError, RuntimeError):
-            print(" > Partial model initialization...")
-            model_dict = model.state_dict()
-            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
-            model.load_state_dict(model_dict)
-            del model_dict
-
-        for group in optimizer.param_groups:
-            group["lr"] = config.lr
-        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
-        args.restore_step = checkpoint["step"]
-    else:
-        args.restore_step = 0
-
-    if use_cuda:
-        model.cuda()
-        criterion.cuda()
-
-    # DISTRIBUTED
-    if num_gpus > 1:
-        model = apply_gradient_allreduce(model)
-
-    if config.noam_schedule:
-        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
-    else:
-        scheduler = None
-
-    num_params = count_parameters(model)
-    print("\n > Model has {} parameters".format(num_params), flush=True)
-
-    if args.restore_step == 0 or not args.best_path:
-        best_loss = float("inf")
-        print(" > Starting with inf best loss.")
-    else:
-        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
-        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
-        print(f" > Starting with loaded last best loss {best_loss}.")
-    keep_all_best = config.keep_all_best
-    keep_after = config.keep_after  # ignored if keep_all_best is False
-
-    # define data loaders
-    train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True)
-    eval_loader = setup_loader(ap, 
model.decoder.r, is_val=True) - - global_step = args.restore_step - for epoch in range(0, config.epochs): - c_logger.print_epoch_start(epoch, config.epochs) - # set gradual training - if config.gradual_training is not None: - r, config.batch_size = gradual_training_scheduler(global_step, config) - config.r = r - model.decoder.set_r(r) - if config.bidirectional_decoder: - model.decoder_backward.set_r(r) - train_loader.dataset.outputs_per_step = r - eval_loader.dataset.outputs_per_step = r - train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset) - eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset) - print("\n > Number of output frames:", model.decoder.r) - # train one epoch - train_avg_loss_dict, global_step = train( - train_loader, - model, - criterion, - optimizer, - optimizer_st, - scheduler, - ap, - global_step, - epoch, - scaler, - scaler_st, - ) - # eval one epoch - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) - c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict["avg_postnet_loss"] - if config.run_eval: - target_loss = eval_avg_loss_dict["avg_postnet_loss"] - best_loss = save_best_model( - target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - config.r, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after, - scaler=scaler.state_dict() if config.mixed_precision else None, - ) - - -if __name__ == "__main__": - args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv) - try: - main(args) - except KeyboardInterrupt: - remove_experiment_folder(OUT_PATH) - try: - sys.exit(0) - except SystemExit: - os._exit(0) # pylint: disable=protected-access - except Exception: # pylint: disable=broad-except - remove_experiment_folder(OUT_PATH) - traceback.print_exc() - sys.exit(1)