visual.py update

This commit is contained in:
Eren Golge 2019-09-05 16:48:36 +02:00
parent 8ff17dfab1
commit 0bb8d780e8
1 changed file with 17 additions and 6 deletions

View File

@@ -1,3 +1,4 @@
import torch
import librosa import librosa
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
@@ -5,10 +6,14 @@ import matplotlib.pyplot as plt
from TTS.utils.text import phoneme_to_sequence, sequence_to_phoneme from TTS.utils.text import phoneme_to_sequence, sequence_to_phoneme
def plot_alignment(alignment, info=None): def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None):
fig, ax = plt.subplots(figsize=(16, 10)) if isinstance(alignment, torch.Tensor):
alignment_ = alignment.detach().cpu().numpy().squeeze()
else:
alignment_ = alignment
fig, ax = plt.subplots(figsize=fig_size)
im = ax.imshow( im = ax.imshow(
alignment.T, aspect='auto', origin='lower', interpolation='none') alignment_.T, aspect='auto', origin='lower', interpolation='none')
fig.colorbar(im, ax=ax) fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep' xlabel = 'Decoder timestep'
if info is not None: if info is not None:
@@ -17,12 +22,18 @@ def plot_alignment(alignment, info=None):
plt.ylabel('Encoder timestep') plt.ylabel('Encoder timestep')
# plt.yticks(range(len(text)), list(text)) # plt.yticks(range(len(text)), list(text))
plt.tight_layout() plt.tight_layout()
if title is not None:
plt.title(title)
return fig return fig
def plot_spectrogram(linear_output, audio): def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
spectrogram = audio._denormalize(linear_output) if isinstance(linear_output, torch.Tensor):
fig = plt.figure(figsize=(16, 10)) linear_output_ = linear_output.detach().cpu().numpy().squeeze()
else:
linear_output_ = linear_output
spectrogram = audio._denormalize(linear_output_)
fig = plt.figure(figsize=fig_size)
plt.imshow(spectrogram.T, aspect="auto", origin="lower") plt.imshow(spectrogram.T, aspect="auto", origin="lower")
plt.colorbar() plt.colorbar()
plt.tight_layout() plt.tight_layout()