diff --git a/train.py b/train.py
index f0e3b68a..cf073956 100644
--- a/train.py
+++ b/train.py
@@ -114,7 +114,7 @@ def format_data(data):
     return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length
 
 
-def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
+def train(model, criterion, optimizer, optimizer_st, scheduler,
           ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=False,
                                verbose=(epoch == 0))
@@ -132,6 +132,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     if c.bidirectional_decoder:
         train_values['avg_decoder_b_loss'] = 0  # decoder backward loss
         train_values['avg_decoder_c_loss'] = 0  # decoder consistency loss
+    if c.ga_alpha > 0:
+        train_values['avg_ga_loss'] = 0  # guided attention loss
     keep_avg = KeepAverage()
     keep_avg.add_values(train_values)
     print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
@@ -164,39 +166,27 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             decoder_output, postnet_output, alignments, stop_tokens = model(
                 text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+            decoder_backward_output = None
 
-        # loss computation
-        stop_loss = criterion_st(stop_tokens,
-                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
-        if c.loss_masking:
-            decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
-            if c.model in ["Tacotron", "TacotronGST"]:
-                postnet_loss = criterion(postnet_output, linear_input,
-                                         mel_lengths)
-            else:
-                postnet_loss = criterion(postnet_output, mel_input,
-                                         mel_lengths)
+        # set the alignment lengths wrt reduction factor for guided attention
+        if mel_lengths.max() % model.decoder.r != 0:
+            alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
         else:
-            decoder_loss = criterion(decoder_output, mel_input)
-            if c.model in ["Tacotron", "TacotronGST"]:
-                postnet_loss = criterion(postnet_output, linear_input)
-            else:
-                postnet_loss = criterion(postnet_output, mel_input)
-        loss = decoder_loss + postnet_loss
-        if not c.separate_stopnet and c.stopnet:
-            loss += stop_loss
+            alignment_lengths = mel_lengths // model.decoder.r
 
-        # backward decoder
+        # compute loss
+        loss_dict = criterion(postnet_output, decoder_output, mel_input,
+                              linear_input, stop_tokens, stop_targets,
+                              mel_lengths, decoder_backward_output,
+                              alignments, alignment_lengths, text_lengths)
         if c.bidirectional_decoder:
-            if c.loss_masking:
-                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
-            else:
-                decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
-            decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
-            loss += decoder_backward_loss + decoder_c_loss
-            keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+            keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
+                                    'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
+        if c.ga_alpha > 0:
+            keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
 
-        loss.backward()
+        # backward pass
+        loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
         grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
         optimizer.step()
@@ -207,7 +197,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
-            stop_loss.backward()
+            loss_dict['stopnet_loss'].backward()
             optimizer_st, _ = adam_weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
@@ -220,36 +210,31 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if global_step % c.print_step == 0:
             print(
                 " | > Step:{}/{} GlobalStep:{} PostnetLoss:{:.5f} "
-                "DecoderLoss:{:.5f} StopLoss:{:.5f} AlignScore:{:.4f} GradNorm:{:.5f} "
+                "DecoderLoss:{:.5f} StopLoss:{:.5f} GALoss:{:.5f} GradNorm:{:.5f} "
                 "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
                 "LoaderTime:{:.2f} LR:{:.6f}".format(
-                    num_iter, batch_n_iter, global_step, postnet_loss.item(),
-                    decoder_loss.item(), stop_loss.item(), align_score,
-                    grad_norm, grad_norm_st, avg_text_length, avg_spec_length,
-                    step_time, loader_time, current_lr),
+                    num_iter, batch_n_iter, global_step, loss_dict['postnet_loss'].item(),
+                    loss_dict['decoder_loss'].item(), loss_dict['stopnet_loss'].item(),
+                    loss_dict['ga_loss'].item(), grad_norm, grad_norm_st, avg_text_length,
+                    avg_spec_length, step_time, loader_time, current_lr),
                 flush=True)
 
         # aggregate losses from processes
        if num_gpus > 1:
-            postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
-            decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
-            loss = reduce_tensor(loss.data, num_gpus)
-            stop_loss = reduce_tensor(stop_loss.data,
-                                      num_gpus) if c.stopnet else stop_loss
+            loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
+            loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
+            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)
+            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data,
+                                                      num_gpus) if c.stopnet else loss_dict['stopnet_loss']
 
         if args.rank == 0:
             update_train_values = {
-                'avg_postnet_loss':
-                float(postnet_loss.item()),
-                'avg_decoder_loss':
-                float(decoder_loss.item()),
-                'avg_stop_loss':
-                stop_loss
-                if isinstance(stop_loss, float) else float(stop_loss.item()),
-                'avg_step_time':
-                step_time,
-                'avg_loader_time':
-                loader_time
+                'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
+                'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
+                'avg_stop_loss': loss_dict['stopnet_loss']
+                if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
+                'avg_step_time': step_time,
+                'avg_loader_time': loader_time
             }
             keep_avg.update_values(update_train_values)
 
@@ -257,8 +242,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             # reduce TB load
             if global_step % 10 == 0:
                 iter_stats = {
-                    "loss_posnet": postnet_loss.item(),
-                    "loss_decoder": decoder_loss.item(),
+                    "loss_posnet": loss_dict['postnet_loss'].item(),
+                    "loss_decoder": loss_dict['decoder_loss'].item(),
                     "lr": current_lr,
                     "grad_norm": grad_norm,
                     "grad_norm_st": grad_norm_st,
@@ -270,7 +255,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             if c.checkpoint:
                 # save model
                 save_checkpoint(model, optimizer, optimizer_st,
-                                postnet_loss.item(), OUT_PATH, global_step,
+                                loss_dict['postnet_loss'].item(), OUT_PATH, global_step,
                                 epoch)
 
             # Diagnostic visualizations
@@ -295,7 +280,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             if c.model in ["Tacotron", "TacotronGST"]:
                 train_audio = ap.inv_spectrogram(const_spec.T)
             else:
-                train_audio = ap.inv_mel_spectrogram(const_spec.T)
+                train_audio = ap.inv_melspectrogram(const_spec.T)
             tb_logger.tb_train_audios(global_step,
                                       {'TrainAudio': train_audio},
                                       c.audio["sample_rate"])
@@ -304,11 +289,11 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     # print epoch stats
     print(" | > EPOCH END -- GlobalStep:{} "
           "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
-          "AvgStopLoss:{:.5f} AvgAlignScore:{:3f} EpochTime:{:.2f} "
+          "AvgStopLoss:{:.5f} AvgGALoss:{:.3f} EpochTime:{:.2f} "
           "AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(
               global_step, keep_avg['avg_postnet_loss'],
               keep_avg['avg_decoder_loss'], keep_avg['avg_stop_loss'],
-              keep_avg['avg_align_score'], epoch_time,
+              keep_avg['avg_ga_loss'], epoch_time,
               keep_avg['avg_step_time'], keep_avg['avg_loader_time']),
           flush=True)
     # Plot Epoch Stats
@@ -321,6 +306,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             "alignment_score": keep_avg['avg_align_score'],
             "epoch_time": epoch_time
         }
+        if c.ga_alpha > 0:
+            epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
         tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
         if c.tb_model_param_stats:
             tb_logger.tb_model_weights(model, global_step)
@@ -328,7 +315,7 @@
 
 
 @torch.no_grad()
-def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
+def evaluate(model, criterion, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     if c.use_speaker_embedding:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
@@ -343,6 +330,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     if c.bidirectional_decoder:
         eval_values_dict['avg_decoder_b_loss'] = 0  # decoder backward loss
         eval_values_dict['avg_decoder_c_loss'] = 0  # decoder consistency loss
+    if c.ga_alpha > 0:
+        eval_values_dict['avg_ga_loss'] = 0  # guided attention loss
     keep_avg = KeepAverage()
     keep_avg.add_values(eval_values_dict)
     print("\n > Validation")
@@ -362,37 +351,26 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             else:
                 decoder_output, postnet_output, alignments, stop_tokens = model(
                     text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+                decoder_backward_output = None
 
-            # loss computation
-            stop_loss = criterion_st(
-                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
-            if c.loss_masking:
-                decoder_loss = criterion(decoder_output, mel_input,
-                                         mel_lengths)
-                if c.model in ["Tacotron", "TacotronGST"]:
-                    postnet_loss = criterion(postnet_output, linear_input,
-                                             mel_lengths)
-                else:
-                    postnet_loss = criterion(postnet_output, mel_input,
-                                             mel_lengths)
+            # set the alignment lengths wrt reduction factor for guided attention
+            if mel_lengths.max() % model.decoder.r != 0:
+                alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
             else:
-                decoder_loss = criterion(decoder_output, mel_input)
-                if c.model in ["Tacotron", "TacotronGST"]:
-                    postnet_loss = criterion(postnet_output, linear_input)
-                else:
-                    postnet_loss = criterion(postnet_output, mel_input)
-            loss = decoder_loss + postnet_loss + stop_loss
+                alignment_lengths = mel_lengths // model.decoder.r
 
-            # backward decoder loss
+            # compute loss
+            loss_dict = criterion(postnet_output, decoder_output, mel_input,
+                                  linear_input, stop_tokens, stop_targets,
+                                  mel_lengths, decoder_backward_output,
+                                  alignments, alignment_lengths, text_lengths)
             if c.bidirectional_decoder:
-                if c.loss_masking:
-                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
-                else:
-                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
-                decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
-                loss += decoder_backward_loss + decoder_c_loss
-                keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+                keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
+                                        'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
+            if c.ga_alpha > 0:
+                keep_avg.update_values({'avg_ga_loss': loss_dict['ga_loss'].item()})
 
+            # step time
             step_time = time.time() - start_time
             epoch_time += step_time
 
@@ -409,23 +387,27 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
 
             keep_avg.update_values({
                 'avg_postnet_loss':
-                float(postnet_loss.item()),
+                float(loss_dict['postnet_loss'].item()),
                 'avg_decoder_loss':
-                float(decoder_loss.item()),
+                float(loss_dict['decoder_loss'].item()),
                 'avg_stop_loss':
-                float(stop_loss.item()),
+                float(loss_dict['stopnet_loss'].item()),
             })
 
             if num_iter % c.print_step == 0:
                 print(
                     " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
-                    "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
-                    .format(loss.item(), postnet_loss.item(),
+                    "StopLoss: {:.5f} - {:.5f} GALoss: {:.5f} - {:.5f} AlignScore: {:.4f} - {:.4f}"
+                    .format(loss_dict['loss'].item(),
+                            loss_dict['postnet_loss'].item(),
                             keep_avg['avg_postnet_loss'],
-                            decoder_loss.item(),
-                            keep_avg['avg_decoder_loss'], stop_loss.item(),
-                            keep_avg['avg_stop_loss'], align_score,
-                            keep_avg['avg_align_score']),
+                            loss_dict['decoder_loss'].item(),
+                            keep_avg['avg_decoder_loss'],
+                            loss_dict['stopnet_loss'].item(),
+                            keep_avg['avg_stop_loss'],
+                            loss_dict['ga_loss'].item(),
+                            keep_avg['avg_ga_loss'],
+                            align_score, keep_avg['avg_align_score']),
                     flush=True)
 
    if args.rank == 0:
@@ -447,7 +429,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             if c.model in ["Tacotron", "TacotronGST"]:
                 eval_audio = ap.inv_spectrogram(const_spec.T)
             else:
-                eval_audio = ap.inv_mel_spectrogram(const_spec.T)
+                eval_audio = ap.inv_melspectrogram(const_spec.T)
             tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                      c.audio["sample_rate"])
 
@@ -456,13 +438,15 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                 "loss_postnet": keep_avg['avg_postnet_loss'],
                 "loss_decoder": keep_avg['avg_decoder_loss'],
                 "stop_loss": keep_avg['avg_stop_loss'],
-                "alignment_score": keep_avg['avg_align_score']
+                "alignment_score": keep_avg['avg_align_score'],
             }
 
             if c.bidirectional_decoder:
                 epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
                 align_b_img = alignments_backward[idx].data.cpu().numpy()
                 eval_figures['alignment_backward'] = plot_alignment(align_b_img)
+            if c.ga_alpha > 0:
+                epoch_stats['guided_attention_loss'] = keep_avg['avg_ga_loss']
 
             tb_logger.tb_eval_stats(global_step, epoch_stats)
             tb_logger.tb_eval_figures(global_step, eval_figures)
@@ -486,7 +470,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     style_wav = c.get("style_wav_for_test")
     for idx, test_sentence in enumerate(test_sentences):
         try:
-            wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
+            wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                 model,
                 test_sentence,
                 c,
@@ -565,14 +549,8 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None
 
-    if c.loss_masking:
-        criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST"
-                                                                ] else MSELossMasked(c.seq_len_norm)
-    else:
-        criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
-                                               ] else nn.MSELoss()
-    criterion_st = BCELossMasked(
-        pos_weight=torch.tensor(10)) if c.stopnet else None
+    # setup criterion
+    criterion = TacotronLoss(c, stopnet_pos_weight=10.0, ga_sigma=0.4)
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path, map_location='cpu')
@@ -600,8 +578,6 @@ def main(args): # pylint: disable=redefined-outer-name
     if use_cuda:
         model.cuda()
         criterion.cuda()
-        if criterion_st:
-            criterion_st.cuda()
 
     # DISTRUBUTED
     if num_gpus > 1:
@@ -631,11 +607,10 @@ def main(args): # pylint: disable=redefined-outer-name
                 model.decoder_backward.set_r(r)
             print(" > Number of outputs per iteration:", model.decoder.r)
 
-        train_loss, global_step = train(model, criterion, criterion_st,
-                                        optimizer, optimizer_st, scheduler, ap,
+        train_loss, global_step = train(model, criterion, optimizer,
+                                        optimizer_st, scheduler, ap,
                                         global_step, epoch)
-        val_loss = evaluate(model, criterion, criterion_st, ap, global_step,
-                            epoch)
+        val_loss = evaluate(model, criterion, ap, global_step, epoch)
         print(" | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
             train_loss, val_loss),
             flush=True)
diff --git a/utils/audio.py b/utils/audio.py
index 771e6a43..31c39eb2 100644
--- a/utils/audio.py
+++ b/utils/audio.py
@@ -31,7 +31,7 @@ class AudioProcessor(object):
                  **_):
 
         print(" > Setting up Audio Processor...")
-
+        # setup class attributes
         self.sample_rate = sample_rate
         self.num_mels = num_mels
         self.min_level_db = min_level_db or 0
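
Reviewer note (not part of the patch): the per-term loss computation now lives in `TacotronLoss`, whose implementation is not included in this diff. For reviewers unfamiliar with the `ga_alpha` / `ga_sigma` options, the sketch below illustrates the guided attention loss the criterion is expected to add: the soft diagonal attention prior of Tachibana et al. (2017), "Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention". The helper names here are hypothetical, and the criterion is assumed to fold the result in as `loss += c.ga_alpha * ga_loss`:

```python
import torch


def make_ga_mask(ilen, olen, sigma=0.4):
    # Soft diagonal penalty W[t, n] = 1 - exp(-((n/N - t/T)^2) / (2 * sigma^2)):
    # near zero on the diagonal, approaching one far from it.
    t = torch.arange(olen, dtype=torch.float32).unsqueeze(1)  # decoder steps, (T, 1)
    n = torch.arange(ilen, dtype=torch.float32).unsqueeze(0)  # encoder steps, (1, N)
    return 1.0 - torch.exp(-((n / ilen - t / olen) ** 2) / (2 * sigma ** 2))


def guided_attention_loss(alignments, text_lengths, alignment_lengths, sigma=0.4):
    # alignments: (B, T_dec, T_enc) attention weights from the decoder.
    # Average the penalized attention mass over the valid (unpadded) region only.
    loss, n_valid = alignments.new_zeros(()), 0
    for b in range(alignments.size(0)):
        t, n = int(alignment_lengths[b]), int(text_lengths[b])
        mask = make_ga_mask(n, t, sigma).to(alignments.device)
        loss = loss + (alignments[b, :t, :n] * mask).sum()
        n_valid += t * n
    return loss / n_valid
```

The `alignment_lengths` arithmetic added in both `train()` and `evaluate()` exists because the decoder emits `r` frames per step, so the attention matrix spans `ceil(max_mel_len / r)` decoder steps rather than `mel_len` frames. For example, with `r = 5` and a batch max mel length of 103, the batch is padded by 2 frames, so a 103-frame sample maps to `(103 + 2) // 5 = 21` decoder steps and a 100-frame sample to `(100 + 2) // 5 = 20`, matching the `T_dec` axis of `alignments`.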