diff --git a/train.py b/train.py index 30133b96..1100c1f3 100644 --- a/train.py +++ b/train.py @@ -20,7 +20,7 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters, load_config, remove_experiment_folder, save_best_model, save_checkpoint, weight_decay, set_init_dict, copy_config_file, setup_model, - split_dataset, gradual_training_scheduler) + split_dataset, gradual_training_scheduler, KeepAverage) from TTS.utils.logger import Logger from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ get_speakers @@ -29,6 +29,7 @@ from TTS.utils.text.symbols import phonemes, symbols from TTS.utils.visual import plot_alignment, plot_spectrogram from TTS.datasets.preprocess import get_preprocessor_by_name from TTS.utils.radam import RAdam +from TTS.utils.measures import alignment_diagonal_score torch.backends.cudnn.enabled = True @@ -45,12 +46,14 @@ def setup_loader(ap, is_val=False, verbose=False): global meta_data_eval if "meta_data_train" not in globals(): if c.meta_file_train is not None: - meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train) + meta_data_train = get_preprocessor_by_name( + c.dataset)(c.data_path, c.meta_file_train) else: meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path) if "meta_data_eval" not in globals() and c.run_eval: if c.meta_file_val is not None: - meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val) + meta_data_eval = get_preprocessor_by_name( + c.dataset)(c.data_path, c.meta_file_val) else: meta_data_eval, meta_data_train = split_dataset(meta_data_train) if is_val and not c.run_eval: @@ -90,14 +93,20 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, speaker_mapping = load_speaker_mapping(OUT_PATH) model.train() epoch_time = 0 - avg_postnet_loss = 0 - avg_decoder_loss = 0 - avg_stop_loss = 0 - avg_step_time = 0 - avg_loader_time = 0 + train_values = { + 'avg_postnet_loss': 0, + 'avg_decoder_loss': 
0, + 'avg_stop_loss': 0, + 'avg_align_score': 0, + 'avg_step_time': 0, + 'avg_loader_time': 0} + keep_avg = KeepAverage() + keep_avg.add_values(train_values) print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True) if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / + (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -108,7 +117,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, text_input = data[0] text_lengths = data[1] speaker_names = data[2] - linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None + linear_input = data[3] if c.model in [ + "Tacotron", "TacotronGST"] else None mel_input = data[4] mel_lengths = data[5] stop_targets = data[6] @@ -126,7 +136,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # set stop targets view, we predict a single stop token per r frames prediction stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze( + 2).float().squeeze(2) global_step += 1 @@ -143,7 +154,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, text_lengths = text_lengths.cuda(non_blocking=True) mel_input = mel_input.cuda(non_blocking=True) mel_lengths = mel_lengths.cuda(non_blocking=True) - linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None + linear_input = linear_input.cuda(non_blocking=True) if c.model in [ + "Tacotron", "TacotronGST"] else None stop_targets = stop_targets.cuda(non_blocking=True) if speaker_ids is not None: speaker_ids = speaker_ids.cuda(non_blocking=True) @@ -153,13 +165,16 @@ def train(model, criterion, criterion_st, optimizer,
optimizer_st, scheduler, text_input, text_lengths, mel_input, speaker_ids=speaker_ids) # loss computation - stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) + stop_loss = criterion_st( + stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) if c.loss_masking: decoder_loss = criterion(decoder_output, mel_input, mel_lengths) if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + postnet_loss = criterion( + postnet_output, linear_input, mel_lengths) else: - postnet_loss = criterion(postnet_output, mel_input, mel_lengths) + postnet_loss = criterion( + postnet_output, mel_input, mel_lengths) else: decoder_loss = criterion(decoder_output, mel_input) if c.model in ["Tacotron", "TacotronGST"]: @@ -175,6 +190,10 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, grad_norm, _ = check_update(model, c.grad_clip) optimizer.step() + # compute alignment score (store a float so the autograd graph is not retained across iterations) + align_score = alignment_diagonal_score(alignments) + keep_avg.update_value('avg_align_score', align_score.item()) + # backpass and check the grad norm for stop loss if c.separate_stopnet: stop_loss.backward() @@ -183,18 +202,18 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, optimizer_st.step() else: grad_norm_st = 0 - + step_time = time.time() - start_time epoch_time += step_time if global_step % c.print_step == 0: print( - " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} " - "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} " + " | > Step:{}/{} GlobalStep:{} PostnetLoss:{:.5f} " + "DecoderLoss:{:.5f} StopLoss:{:.5f} AlignScore:{:.4f} GradNorm:{:.5f} " "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} " "LoaderTime:{:.2f} LR:{:.6f}".format( - num_iter, batch_n_iter, global_step, loss.item(), - postnet_loss.item(), decoder_loss.item(), stop_loss.item(), + num_iter, batch_n_iter, global_step, + postnet_loss.item(),
decoder_loss.item(), stop_loss.item(), align_score.item(), grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, loader_time, current_lr), flush=True) @@ -204,14 +223,16 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, postnet_loss = reduce_tensor(postnet_loss.data, num_gpus) decoder_loss = reduce_tensor(decoder_loss.data, num_gpus) loss = reduce_tensor(loss.data, num_gpus) - stop_loss = reduce_tensor(stop_loss.data, num_gpus) if c.stopnet else stop_loss + stop_loss = reduce_tensor( + stop_loss.data, num_gpus) if c.stopnet else stop_loss if args.rank == 0: - avg_postnet_loss += float(postnet_loss.item()) - avg_decoder_loss += float(decoder_loss.item()) - avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item()) - avg_step_time += step_time - avg_loader_time += loader_time + update_train_values = {'avg_postnet_loss': float(postnet_loss.item()), + 'avg_decoder_loss': float(decoder_loss.item()), + 'avg_stop_loss': stop_loss if isinstance(stop_loss, float) else float(stop_loss.item()), + 'avg_step_time': step_time, + 'avg_loader_time': loader_time} + keep_avg.update_values(update_train_values) # Plot Training Iter Stats # reduce TB load @@ -233,7 +254,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy() - gt_spec = linear_input[0].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy() + gt_spec = linear_input[0].data.cpu().numpy() if c.model in [ + "Tacotron", "TacotronGST"] else mel_input[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy() figures = { @@ -253,35 +275,28 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, c.audio["sample_rate"]) end_time = time.time() - avg_postnet_loss /= (num_iter + 1) - avg_decoder_loss /= (num_iter + 1) - avg_stop_loss /= (num_iter + 1) - avg_total_loss = 
avg_decoder_loss + avg_postnet_loss + avg_stop_loss - avg_step_time /= (num_iter + 1) - avg_loader_time /= (num_iter + 1) - # print epoch stats print( " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} " "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} " "AvgStopLoss:{:.5f} EpochTime:{:.2f} " - "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss, - avg_postnet_loss, avg_decoder_loss, - avg_stop_loss, epoch_time, avg_step_time, - avg_loader_time), + "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, keep_avg['avg_postnet_loss'] + keep_avg['avg_decoder_loss'] + keep_avg['avg_stop_loss'], + keep_avg['avg_postnet_loss'], keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'], + epoch_time, keep_avg['avg_step_time'], keep_avg['avg_loader_time']), flush=True) # Plot Epoch Stats if args.rank == 0: # Plot Training Epoch Stats - epoch_stats = {"loss_postnet": avg_postnet_loss, - "loss_decoder": avg_decoder_loss, - "stop_loss": avg_stop_loss, + epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'], + "loss_decoder": keep_avg['avg_decoder_loss'], + "stop_loss": keep_avg['avg_stop_loss'], + "alignment_score": keep_avg['avg_align_score'], "epoch_time": epoch_time} tb_logger.tb_train_epoch_stats(global_step, epoch_stats) if c.tb_model_param_stats: tb_logger.tb_model_weights(model, global_step) - return avg_postnet_loss, global_step + return keep_avg['avg_postnet_loss'], global_step def evaluate(model, criterion, criterion_st, ap, global_step, epoch): @@ -290,9 +305,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): speaker_mapping = load_speaker_mapping(OUT_PATH) model.eval() epoch_time = 0 - avg_postnet_loss = 0 - avg_decoder_loss = 0 - avg_stop_loss = 0 + eval_values_dict = {'avg_postnet_loss' : 0, + 'avg_decoder_loss' : 0, + 'avg_stop_loss' : 0, + 'avg_align_score': 0} + keep_avg = KeepAverage() + keep_avg.add_values(eval_values_dict) print("\n > Validation") if c.test_sentences_file is None: test_sentences = [ @@ -313,7 +331,8 @@ def evaluate(model, criterion,
criterion_st, ap, global_step, epoch): text_input = data[0] text_lengths = data[1] speaker_names = data[2] - linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None + linear_input = data[3] if c.model in [ + "Tacotron", "TacotronGST"] else None mel_input = data[4] mel_lengths = data[5] stop_targets = data[6] @@ -329,14 +348,16 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze( + 2).float().squeeze(2) # dispatch data to GPU if use_cuda: text_input = text_input.cuda() mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - linear_input = linear_input.cuda() if c.model in ["Tacotron", "TacotronGST"] else None + linear_input = linear_input.cuda() if c.model in [ + "Tacotron", "TacotronGST"] else None stop_targets = stop_targets.cuda() if speaker_ids is not None: speaker_ids = speaker_ids.cuda() @@ -347,13 +368,17 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): speaker_ids=speaker_ids) # loss computation - stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) + stop_loss = criterion_st( + stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) if c.loss_masking: - decoder_loss = criterion(decoder_output, mel_input, mel_lengths) + decoder_loss = criterion( + decoder_output, mel_input, mel_lengths) if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input, mel_lengths) + postnet_loss = criterion( + postnet_output, linear_input, mel_lengths) else: - postnet_loss = criterion(postnet_output, mel_input, mel_lengths) + postnet_loss = criterion( + postnet_output, mel_input, mel_lengths) else: decoder_loss = criterion(decoder_output, mel_input) if c.model in ["Tacotron", "TacotronGST"]: @@ -365,14 +390,9 @@ def 
evaluate(model, criterion, criterion_st, ap, global_step, epoch): step_time = time.time() - start_time epoch_time += step_time - if num_iter % c.print_step == 0: - print( - " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} DecoderLoss:{:.5f} " - "StopLoss: {:.5f} ".format(loss.item(), - postnet_loss.item(), - decoder_loss.item(), - stop_loss.item()), - flush=True) + # compute alignment score (store a float so no tensor graph is kept in the running average) + align_score = alignment_diagonal_score(alignments) + keep_avg.update_value('avg_align_score', align_score.item()) # aggregate losses from processes if num_gpus > 1: @@ -381,15 +401,26 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): if c.stopnet: stop_loss = reduce_tensor(stop_loss.data, num_gpus) - avg_postnet_loss += float(postnet_loss.item()) - avg_decoder_loss += float(decoder_loss.item()) - avg_stop_loss += stop_loss.item() + keep_avg.update_values({'avg_postnet_loss' : float(postnet_loss.item()), + 'avg_decoder_loss' : float(decoder_loss.item()), + 'avg_stop_loss' : float(stop_loss.item())}) + + if num_iter % c.print_step == 0: + print( + " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " + "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(loss.item(), + postnet_loss.item(), keep_avg['avg_postnet_loss'], + decoder_loss.item(), keep_avg['avg_decoder_loss'], + stop_loss.item(), keep_avg['avg_stop_loss'], + align_score.item(), keep_avg['avg_align_score']), + flush=True) if args.rank == 0: # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = linear_input[idx].data.cpu().numpy() if c.model in ["Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy() + gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ + "Tacotron", "TacotronGST"] else mel_input[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy() eval_figures = { @@ -404,17 +435,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step,
epoch): eval_audio = ap.inv_spectrogram(const_spec.T) else: eval_audio = ap.inv_mel_spectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) - - # compute average losses - avg_postnet_loss /= (num_iter + 1) - avg_decoder_loss /= (num_iter + 1) - avg_stop_loss /= (num_iter + 1) + tb_logger.tb_eval_audios( + global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats - epoch_stats = {"loss_postnet": avg_postnet_loss, - "loss_decoder": avg_decoder_loss, - "stop_loss": avg_stop_loss} + epoch_stats = {"loss_postnet": keep_avg['avg_postnet_loss'], + "loss_decoder": keep_avg['avg_decoder_loss'], + "stop_loss": keep_avg['avg_stop_loss']} tb_logger.tb_eval_stats(global_step, epoch_stats) if args.rank == 0 and epoch > c.test_delay_epochs: @@ -436,18 +463,21 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram(postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment) + test_figures['{}-prediction'.format(idx) + ] = plot_spectrogram(postnet_output, ap) + test_figures['{}-alignment'.format(idx) + ] = plot_alignment(alignment) except: print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, c.audio['sample_rate']) + tb_logger.tb_test_audios( + global_step, test_audios, c.audio['sample_rate']) tb_logger.tb_test_figures(global_step, test_figures) - return avg_postnet_loss + return keep_avg['avg_postnet_loss'] -#FIXME: move args definition/parsing inside of main? -def main(args): #pylint: disable=redefined-outer-name +# FIXME: move args definition/parsing inside of main? 
+def main(args): # pylint: disable=redefined-outer-name # Audio processor ap = AudioProcessor(**c.audio) @@ -488,9 +518,11 @@ def main(args): #pylint: disable=redefined-outer-name optimizer_st = None if c.loss_masking: - criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"] else MSELossMasked() + criterion = L1LossMasked() if c.model in [ + "Tacotron", "TacotronGST"] else MSELossMasked() else: - criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"] else nn.MSELoss() + criterion = nn.L1Loss() if c.model in [ + "Tacotron", "TacotronGST"] else nn.MSELoss() criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None if args.restore_path: @@ -552,7 +584,8 @@ def main(args): #pylint: disable=redefined-outer-name train_loss, global_step = train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, ap, global_step, epoch) - val_loss = evaluate(model, criterion, criterion_st, ap, global_step, epoch) + val_loss = evaluate(model, criterion, criterion_st, + ap, global_step, epoch) print( " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format( train_loss, val_loss), @@ -635,7 +668,8 @@ if __name__ == '__main__': if args.restore_path: new_fields["restore_path"] = args.restore_path new_fields["github_branch"] = get_git_branch() - copy_config_file(args.config_path, os.path.join(OUT_PATH, 'config.json'), new_fields) + copy_config_file(args.config_path, os.path.join( + OUT_PATH, 'config.json'), new_fields) os.chmod(AUDIO_PATH, 0o775) os.chmod(OUT_PATH, 0o775) @@ -650,8 +684,8 @@ if __name__ == '__main__': try: sys.exit(0) except SystemExit: - os._exit(0) #pylint: disable=protected-access - except Exception: #pylint: disable=broad-except + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except remove_experiment_folder(OUT_PATH) traceback.print_exc() sys.exit(1) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 1c16834a..d72ffdd5 100644 --- a/utils/generic_utils.py +++ 
b/utils/generic_utils.py @@ -313,4 +313,35 @@ def gradual_training_scheduler(global_step, config): for values in config.gradual_training: if global_step >= values[0]: new_values = values - return new_values[1], new_values[2] \ No newline at end of file + return new_values[1], new_values[2] + + +class KeepAverage(): + def __init__(self): + self.avg_values = {} + self.iters = {} + + def __getitem__(self, key): + return self.avg_values[key] + + def add_value(self, name, init_val=0, init_iter=0): + self.avg_values[name] = init_val + self.iters[name] = init_iter + + def update_value(self, name, value, weighted_avg=False): + if weighted_avg: + self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value + self.iters[name] += 1 + else: + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value + self.iters[name] += 1 + self.avg_values[name] /= self.iters[name] + + def add_values(self, name_dict): + for key, value in name_dict.items(): + self.add_value(key, init_val=value) + + def update_values(self, value_dict): + for key, value in value_dict.items(): + self.update_value(key, value) + diff --git a/utils/measures.py b/utils/measures.py new file mode 100644 index 00000000..21652cf0 --- /dev/null +++ b/utils/measures.py @@ -0,0 +1,19 @@ +import torch +import numpy as np + + +def alignment_diagonal_score(alignments): + """ + Compute how diagonal alignment predictions are. It is useful + to measure the alignment consistency of a model + Args: + alignments (torch.Tensor): batch of alignments. + Shape: + alignments : batch x decoder_steps x encoder_steps + """ + return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0) + + + + +