mirror of https://github.com/coqui-ai/TTS.git

Commit 673ba74a80: glow tts training and inference fixes
Parent: d5c6d60884

@@ -18,7 +18,7 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     if use_cuda:
         model.cuda()
     # set model stepsize
-    if 'r' in state.keys():
+    if hasattr(model.decoder, 'r'):
         model.decoder.set_r(state['r'])
     return model, state
 
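The hasattr guard matters because a GlowTTS decoder carries no Tacotron-style reduction factor r, so restoring it blindly breaks checkpoint loading for glow tts. Below is a minimal standalone sketch of the guarded restore, not the library function: the checkpoint layout ('model' and 'r' keys) and the combined key-plus-attribute check are assumptions made for illustration.

    import torch

    def load_checkpoint_sketch(model, checkpoint_path, use_cuda=False):
        # Illustrative re-statement of the patched logic above, not TTS's own helper.
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(state['model'])  # assumed checkpoint key
        if use_cuda:
            model.cuda()
        # Only attention-based decoders expose a reduction factor; GlowTTS's decoder
        # has no r/set_r, so the attribute check avoids an AttributeError, and keeping
        # the key check avoids a KeyError on checkpoints that never stored 'r'.
        if 'r' in state and hasattr(model.decoder, 'r'):
            model.decoder.set_r(state['r'])
        return model, state
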
@@ -107,9 +107,9 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
 
 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = decoder_output[0].data.cpu().numpy()
+    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = stop_tokens[0].cpu().numpy()
+    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
     return postnet_output, decoder_output, alignment, stop_tokens
 
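The guards exist because glow tts inference produces no stopnet predictions (and, depending on the call path, no separate decoder output), so these values arrive as None and must not be indexed. A self-contained sketch of the same pattern; the helper name and dummy shapes are invented for illustration:

    import torch

    def to_numpy_or_none(x):
        # Same guard as above: take the first batch item and move it to CPU,
        # but pass None through untouched instead of crashing on None[0].
        return None if x is None else x[0].detach().cpu().numpy()

    postnet_output = torch.randn(1, 200, 80)       # dummy [batch, frames, mel bins]
    stop_tokens = None                             # glow-tts style model: no stopnet
    print(to_numpy_or_none(postnet_output).shape)  # (200, 80)
    print(to_numpy_or_none(stop_tokens))           # None
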
@@ -6,14 +6,20 @@ import matplotlib.pyplot as plt
 from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
 
 
-def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
+def plot_alignment(alignment,
+                   info=None,
+                   fig_size=(16, 10),
+                   title=None,
+                   output_fig=False):
     if isinstance(alignment, torch.Tensor):
         alignment_ = alignment.detach().cpu().numpy().squeeze()
     else:
         alignment_ = alignment
     fig, ax = plt.subplots(figsize=fig_size)
-    im = ax.imshow(
-        alignment_.T, aspect='auto', origin='lower', interpolation='none')
+    im = ax.imshow(alignment_.T,
+                   aspect='auto',
+                   origin='lower',
+                   interpolation='none')
     fig.colorbar(im, ax=ax)
     xlabel = 'Decoder timestep'
     if info is not None:
@@ -29,7 +35,10 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_f
     return fig
 
 
-def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
+def plot_spectrogram(spectrogram,
+                     ap=None,
+                     fig_size=(16, 10),
+                     output_fig=False):
     if isinstance(spectrogram, torch.Tensor):
         spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
     else:
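The two plotting signatures above change formatting only (one argument per line); behaviour is unchanged. For reference, a hedged usage sketch of plot_alignment as defined in this module, assuming it is imported from here; the dummy attention shape and the info/title strings are invented:

    import torch

    # Dummy soft attention map (batch of 1); plot_alignment detaches, moves to CPU
    # and squeezes torch tensors itself, as the isinstance branch above shows.
    att = torch.softmax(torch.randn(1, 120, 50), dim=-1)

    fig = plot_alignment(att, info="step 1000", title="attention")
    fig = plot_alignment(att.squeeze(0).numpy(), fig_size=(12, 8))
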
@@ -45,7 +54,17 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
     return fig
 
 
-def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
+def visualize(alignment,
+              postnet_output,
+              text,
+              hop_length,
+              CONFIG,
+              stop_tokens=None,
+              decoder_output=None,
+              output_path=None,
+              figsize=(8, 24),
+              output_fig=False):
+
     if decoder_output is not None:
         num_plot = 4
     else:
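Because stop_tokens moves from the third positional argument to an optional keyword, call sites that passed it positionally must be updated. Hypothetical call sites against the new signature (the variable names simply echo the parameter names above):

    # Attention-based model: stopnet predictions and decoder output are available.
    visualize(alignment, postnet_output, text, hop_length, CONFIG,
              stop_tokens=stop_tokens,
              decoder_output=decoder_output)

    # GlowTTS-style model: no stopnet and no separate decoder output, so omit both.
    visualize(alignment, postnet_output, text, hop_length, CONFIG)
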
@@ -60,18 +79,30 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     # compute phoneme representation and back
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
-        text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        seq = phoneme_to_sequence(
+            text, [CONFIG.text_cleaner],
+            CONFIG.phoneme_language,
+            CONFIG.enable_eos_bos_chars,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        text = sequence_to_phoneme(
+            seq,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
         print(text)
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()
-    # plot stopnet predictions
-    plt.subplot(num_plot, 1, 2)
-    plt.plot(range(len(stop_tokens)), list(stop_tokens))
+
+    if stop_tokens is not None:
+        # plot stopnet predictions
+        plt.subplot(num_plot, 1, 2)
+        plt.plot(range(len(stop_tokens)), list(stop_tokens))
 
     # plot postnet spectrogram
     plt.subplot(num_plot, 1, 3)
-    librosa.display.specshow(postnet_output.T, sr=CONFIG.audio['sample_rate'],
-                             hop_length=hop_length, x_axis="time", y_axis="linear",
+    librosa.display.specshow(postnet_output.T,
+                             sr=CONFIG.audio['sample_rate'],
+                             hop_length=hop_length,
+                             x_axis="time",
+                             y_axis="linear",
                              fmin=CONFIG.audio['mel_fmin'],
                              fmax=CONFIG.audio['mel_fmax'])
+
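The phoneme round trip (phoneme_to_sequence followed by sequence_to_phoneme) is only reflowed here, not changed. A standalone sketch of that round trip with literal values standing in for the CONFIG fields; the particular cleaner name, language code, and an installed espeak backend are assumptions:

    from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme

    # Arguments mirror the call above: text, cleaner list, phoneme language,
    # enable_eos_bos flag, and tp (custom character set; None when the config
    # defines no 'characters' block).
    seq = phoneme_to_sequence("hello world",
                              ["phoneme_cleaners"],
                              "en-us",
                              False,
                              tp=None)
    print(sequence_to_phoneme(seq, tp=None))  # prints the phonemized text back
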
@@ -82,8 +113,11 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
 
     if decoder_output is not None:
         plt.subplot(num_plot, 1, 4)
-        librosa.display.specshow(decoder_output.T, sr=CONFIG.audio['sample_rate'],
-                                 hop_length=hop_length, x_axis="time", y_axis="linear",
+        librosa.display.specshow(decoder_output.T,
+                                 sr=CONFIG.audio['sample_rate'],
+                                 hop_length=hop_length,
+                                 x_axis="time",
+                                 y_axis="linear",
                                  fmin=CONFIG.audio['mel_fmin'],
                                  fmax=CONFIG.audio['mel_fmax'])
         plt.xlabel("Time", fontsize=label_fontsize)