diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index f7a69eb4..4afa5d8f 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -255,13 +255,13 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, align_img = alignments[0].data.cpu().numpy() figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img), + "prediction": plot_spectrogram(const_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), } if c.bidirectional_decoder or c.double_decoder_consistency: - figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy()) + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) tb_logger.tb_train_figures(global_step, figures) @@ -369,9 +369,9 @@ def evaluate(model, criterion, ap, global_step, epoch): align_img = alignments[idx].data.cpu().numpy() eval_figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img) + "prediction": plot_spectrogram(const_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False) } # Sample audio @@ -386,7 +386,7 @@ def evaluate(model, criterion, ap, global_step, epoch): if c.bidirectional_decoder or c.double_decoder_consistency: align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures['alignment2'] = plot_alignment(align_b_img) + eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) tb_logger.tb_eval_figures(global_step, eval_figures) @@ -431,9 +431,9 @@ def evaluate(model, criterion, ap, global_step, epoch): ap.save_wav(wav, file_path) test_audios['{}-audio'.format(idx)] = wav test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap) + postnet_output, ap, output_fig=False) test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) + alignment, output_fig=False) except: print(" !! Error creating Test Sentence -", idx) traceback.print_exc() diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index d823aa91..63ccd077 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme -def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None): +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): if isinstance(alignment, torch.Tensor): alignment_ = alignment.detach().cpu().numpy().squeeze() else: @@ -24,10 +24,12 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None): plt.tight_layout() if title is not None: plt.title(title) + if not output_fig: + plt.close() return fig -def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)): +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): if isinstance(spectrogram, torch.Tensor): spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T else: @@ -38,10 +40,12 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)): plt.imshow(spectrogram_, aspect="auto", origin="lower") plt.colorbar() plt.tight_layout() + if not output_fig: + plt.close() return fig -def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)): +def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False): if decoder_output is not None: num_plot = 4 else: @@ -91,3 +95,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, print(output_path) fig.savefig(output_path) plt.close() + + if not output_fig: + plt.close()