close figures in training to prevent OOM

This commit is contained in:
erogol 2020-08-03 11:40:43 +02:00
parent 180d443765
commit b016c655ea
2 changed files with 20 additions and 13 deletions

View File

@@ -255,13 +255,13 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
 align_img = alignments[0].data.cpu().numpy()
 figures = {
-    "prediction": plot_spectrogram(const_spec, ap),
+    "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
-    "ground_truth": plot_spectrogram(gt_spec, ap),
+    "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
-    "alignment": plot_alignment(align_img),
+    "alignment": plot_alignment(align_img, output_fig=False),
 }
 if c.bidirectional_decoder or c.double_decoder_consistency:
-    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy())
+    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
 tb_logger.tb_train_figures(global_step, figures)
@@ -369,9 +369,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
 align_img = alignments[idx].data.cpu().numpy()
 eval_figures = {
-    "prediction": plot_spectrogram(const_spec, ap),
+    "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
-    "ground_truth": plot_spectrogram(gt_spec, ap),
+    "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
-    "alignment": plot_alignment(align_img)
+    "alignment": plot_alignment(align_img, output_fig=False)
 }
 # Sample audio
@@ -386,7 +386,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 if c.bidirectional_decoder or c.double_decoder_consistency:
     align_b_img = alignments_backward[idx].data.cpu().numpy()
-    eval_figures['alignment2'] = plot_alignment(align_b_img)
+    eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
 tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
 tb_logger.tb_eval_figures(global_step, eval_figures)
@@ -431,9 +431,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
 ap.save_wav(wav, file_path)
 test_audios['{}-audio'.format(idx)] = wav
 test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
-    postnet_output, ap)
+    postnet_output, ap, output_fig=False)
 test_figures['{}-alignment'.format(idx)] = plot_alignment(
-    alignment)
+    alignment, output_fig=False)
 except:
     print(" !! Error creating Test Sentence -", idx)
     traceback.print_exc()

View File

@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
 from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
-def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None):
+def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
 if isinstance(alignment, torch.Tensor):
     alignment_ = alignment.detach().cpu().numpy().squeeze()
 else:
@@ -24,10 +24,12 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None):
 plt.tight_layout()
 if title is not None:
     plt.title(title)
+if not output_fig:
+    plt.close()
 return fig
-def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)):
+def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
 if isinstance(spectrogram, torch.Tensor):
     spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
 else:
@@ -38,10 +40,12 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)):
 plt.imshow(spectrogram_, aspect="auto", origin="lower")
 plt.colorbar()
 plt.tight_layout()
+if not output_fig:
+    plt.close()
 return fig
-def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
+def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
 if decoder_output is not None:
     num_plot = 4
 else:
@@ -91,3 +95,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
 print(output_path)
 fig.savefig(output_path)
 plt.close()
+if not output_fig:
+    plt.close()