glow tts training and inference fixes

erogol 2020-08-05 12:59:40 +02:00
parent d5c6d60884
commit 673ba74a80
3 changed files with 51 additions and 17 deletions

View File

@@ -18,7 +18,7 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     if use_cuda:
         model.cuda()
     # set model stepsize
-    if 'r' in state.keys():
+    if hasattr(model.decoder, 'r'):
         model.decoder.set_r(state['r'])
     return model, state
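The loader change above guards the stepsize restore on the model instead of the checkpoint: GlowTTS has no autoregressive decoder with a reduction factor `r`, so calling `set_r` on it would fail even when the checkpoint dict carries an 'r' key. A minimal sketch of that pattern, using hypothetical stand-in decoder classes rather than the real TTS models:

class TacotronLikeDecoder:
    """Stand-in for a decoder that exposes a reduction factor (stepsize)."""
    def __init__(self):
        self.r = 7

    def set_r(self, r):
        self.r = r


class GlowLikeDecoder:
    """Stand-in for a decoder with no `r` attribute at all."""


def restore_stepsize(decoder, state):
    # Check the model, not the checkpoint: checkpoints that carry 'r' no
    # longer break models whose decoder cannot consume it.
    if hasattr(decoder, 'r'):
        decoder.set_r(state['r'])


state = {'r': 2}
for decoder in (TacotronLikeDecoder(), GlowLikeDecoder()):
    restore_stepsize(decoder, state)   # the Glow-like decoder is skipped safely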

View File

@@ -107,9 +107,9 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = decoder_output[0].data.cpu().numpy()
+    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = stop_tokens[0].cpu().numpy()
+    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
     return postnet_output, decoder_output, alignment, stop_tokens
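parse_outputs_torch now passes None through instead of indexing it, since GlowTTS returns neither a pre-postnet decoder output nor stop tokens. A hedged sketch of the same conversion with dummy tensors; `to_numpy_or_none` is a helper name made up for this example, not part of the repo:

import torch


def to_numpy_or_none(batch_tensor):
    # Take the first item of a batched tensor and move it to numpy,
    # but let missing outputs (None) flow through untouched.
    return None if batch_tensor is None else batch_tensor[0].data.cpu().numpy()


postnet_output = torch.rand(1, 80, 120)   # dummy [batch, mel bins, frames]
decoder_output = None                     # GlowTTS: no separate decoder output
stop_tokens = None                        # GlowTTS: no stopnet

print(to_numpy_or_none(postnet_output).shape)   # (80, 120)
print(to_numpy_or_none(decoder_output))         # None
print(to_numpy_or_none(stop_tokens))            # None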

View File

@@ -6,14 +6,20 @@ import matplotlib.pyplot as plt
 from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
-def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
+def plot_alignment(alignment,
+                   info=None,
+                   fig_size=(16, 10),
+                   title=None,
+                   output_fig=False):
     if isinstance(alignment, torch.Tensor):
         alignment_ = alignment.detach().cpu().numpy().squeeze()
     else:
         alignment_ = alignment
     fig, ax = plt.subplots(figsize=fig_size)
-    im = ax.imshow(
-        alignment_.T, aspect='auto', origin='lower', interpolation='none')
+    im = ax.imshow(alignment_.T,
+                   aspect='auto',
+                   origin='lower',
+                   interpolation='none')
     fig.colorbar(im, ax=ax)
     xlabel = 'Decoder timestep'
     if info is not None:
@@ -29,7 +35,10 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_f
     return fig
-def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
+def plot_spectrogram(spectrogram,
+                     ap=None,
+                     fig_size=(16, 10),
+                     output_fig=False):
     if isinstance(spectrogram, torch.Tensor):
         spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
     else:
@@ -45,7 +54,17 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
     return fig
-def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
+def visualize(alignment,
+              postnet_output,
+              text,
+              hop_length,
+              CONFIG,
+              stop_tokens=None,
+              decoder_output=None,
+              output_path=None,
+              figsize=(8, 24),
+              output_fig=False):
     if decoder_output is not None:
         num_plot = 4
     else:
@@ -60,18 +79,30 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     # compute phoneme representation and back
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
-        text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        seq = phoneme_to_sequence(
+            text, [CONFIG.text_cleaner],
+            CONFIG.phoneme_language,
+            CONFIG.enable_eos_bos_chars,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        text = sequence_to_phoneme(
+            seq,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
         print(text)
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()
-    # plot stopnet predictions
-    plt.subplot(num_plot, 1, 2)
-    plt.plot(range(len(stop_tokens)), list(stop_tokens))
+    if stop_tokens is not None:
+        # plot stopnet predictions
+        plt.subplot(num_plot, 1, 2)
+        plt.plot(range(len(stop_tokens)), list(stop_tokens))
     # plot postnet spectrogram
     plt.subplot(num_plot, 1, 3)
-    librosa.display.specshow(postnet_output.T, sr=CONFIG.audio['sample_rate'],
-                             hop_length=hop_length, x_axis="time", y_axis="linear",
+    librosa.display.specshow(postnet_output.T,
+                             sr=CONFIG.audio['sample_rate'],
+                             hop_length=hop_length,
+                             x_axis="time",
+                             y_axis="linear",
                              fmin=CONFIG.audio['mel_fmin'],
                              fmax=CONFIG.audio['mel_fmax'])
@@ -82,8 +113,11 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
     if decoder_output is not None:
         plt.subplot(num_plot, 1, 4)
-        librosa.display.specshow(decoder_output.T, sr=CONFIG.audio['sample_rate'],
-                                 hop_length=hop_length, x_axis="time", y_axis="linear",
+        librosa.display.specshow(decoder_output.T,
+                                 sr=CONFIG.audio['sample_rate'],
+                                 hop_length=hop_length,
+                                 x_axis="time",
+                                 y_axis="linear",
                                  fmin=CONFIG.audio['mel_fmin'],
                                  fmax=CONFIG.audio['mel_fmax'])
         plt.xlabel("Time", fontsize=label_fontsize)
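With stop_tokens demoted to an optional keyword (callers that used to pass it as the third positional argument need updating) and the stopnet panel drawn only when it exists, visualize can now handle GlowTTS outputs. Below is a self-contained sketch of that conditional layout using plain numpy and matplotlib with random arrays standing in for model outputs; it mirrors the logic of the patched function rather than calling it:

import numpy as np
import matplotlib.pyplot as plt

alignment = np.random.rand(120, 42)        # dummy [decoder frames, encoder steps]
postnet_output = np.random.rand(120, 80)   # dummy [frames, mel bins]
stop_tokens = None                         # GlowTTS: no stopnet
decoder_output = None                      # GlowTTS: no separate decoder output

num_plot = 4 if decoder_output is not None else 3

fig = plt.figure(figsize=(8, 24))
plt.subplot(num_plot, 1, 1)
plt.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none')

if stop_tokens is not None:
    # plot stopnet predictions (skipped for GlowTTS)
    plt.subplot(num_plot, 1, 2)
    plt.plot(range(len(stop_tokens)), list(stop_tokens))

# postnet spectrogram keeps its panel index even when panel 2 is skipped
plt.subplot(num_plot, 1, 3)
plt.imshow(postnet_output.T, aspect='auto', origin='lower')

if decoder_output is not None:
    plt.subplot(num_plot, 1, 4)
    plt.imshow(decoder_output.T, aspect='auto', origin='lower')

fig.savefig('glow_tts_vis.png')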