mirror of https://github.com/coqui-ai/TTS.git
close figures in training to prevent OOM
This commit is contained in: parent 180d443765, commit b016c655ea
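The change relies on one matplotlib behaviour: pyplot registers every figure in a global manager and keeps it alive until plt.close() is called, so building fresh spectrogram and alignment plots at every logging step steadily grows memory. Closing the figure only removes it from that registry; the Figure object the plot helpers return can still be rasterized afterwards, which is all the TensorBoard logger needs. The sketch below shows the same output_fig pattern in isolation; the plot_matrix helper and the loop are illustrative stand-ins, not code from this repo.

import io

import matplotlib
matplotlib.use("Agg")  # headless backend, as used for training-time plotting
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg


def plot_matrix(matrix, fig_size=(16, 10), output_fig=False):
    # Same pattern as plot_spectrogram/plot_alignment after this commit:
    # draw, then close unless the caller explicitly wants the figure kept open.
    fig = plt.figure(figsize=fig_size)
    plt.imshow(matrix, aspect="auto", origin="lower")
    plt.colorbar()
    plt.tight_layout()
    if not output_fig:
        plt.close()  # drop it from pyplot's registry; fig itself stays valid
    return fig


if __name__ == "__main__":
    # Hypothetical logging loop: without plt.close(), every iteration would
    # leave another live figure behind and memory use would climb.
    for step in range(3):
        fig = plot_matrix(np.random.rand(80, 200), output_fig=False)
        FigureCanvasAgg(fig)           # attach a fresh canvas to rasterize it,
        buf = io.BytesIO()             # roughly what a TensorBoard logger does
        fig.savefig(buf, format="png")
    print("open figures:", plt.get_fignums())  # -> [] (nothing left open)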
@@ -255,13 +255,13 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         align_img = alignments[0].data.cpu().numpy()
 
         figures = {
-            "prediction": plot_spectrogram(const_spec, ap),
-            "ground_truth": plot_spectrogram(gt_spec, ap),
-            "alignment": plot_alignment(align_img),
+            "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
         }
 
         if c.bidirectional_decoder or c.double_decoder_consistency:
-            figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy())
+            figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
 
         tb_logger.tb_train_figures(global_step, figures)
 
@@ -369,9 +369,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
         align_img = alignments[idx].data.cpu().numpy()
 
         eval_figures = {
-            "prediction": plot_spectrogram(const_spec, ap),
-            "ground_truth": plot_spectrogram(gt_spec, ap),
-            "alignment": plot_alignment(align_img)
+            "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False)
         }
 
         # Sample audio
@@ -386,7 +386,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
 
         if c.bidirectional_decoder or c.double_decoder_consistency:
             align_b_img = alignments_backward[idx].data.cpu().numpy()
-            eval_figures['alignment2'] = plot_alignment(align_b_img)
+            eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
         tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
         tb_logger.tb_eval_figures(global_step, eval_figures)
 
@@ -431,9 +431,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
                 ap.save_wav(wav, file_path)
                 test_audios['{}-audio'.format(idx)] = wav
                 test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
-                    postnet_output, ap)
+                    postnet_output, ap, output_fig=False)
                 test_figures['{}-alignment'.format(idx)] = plot_alignment(
-                    alignment)
+                    alignment, output_fig=False)
             except:
                 print(" !! Error creating Test Sentence -", idx)
                 traceback.print_exc()
@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
 from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
 
 
-def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None):
+def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
     if isinstance(alignment, torch.Tensor):
         alignment_ = alignment.detach().cpu().numpy().squeeze()
     else:
@@ -24,10 +24,12 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None):
     plt.tight_layout()
     if title is not None:
         plt.title(title)
+    if not output_fig:
+        plt.close()
     return fig
 
 
-def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)):
+def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
     if isinstance(spectrogram, torch.Tensor):
         spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
     else:
@@ -38,10 +40,12 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10)):
     plt.imshow(spectrogram_, aspect="auto", origin="lower")
     plt.colorbar()
     plt.tight_layout()
+    if not output_fig:
+        plt.close()
     return fig
 
 
-def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
+def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
     if decoder_output is not None:
         num_plot = 4
     else:
@@ -91,3 +95,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
         print(output_path)
         fig.savefig(output_path)
         plt.close()
+
+    if not output_fig:
+        plt.close()