From 673ba74a802a00170ec057fbdd260863601fcf95 Mon Sep 17 00:00:00 2001
From: erogol
Date: Wed, 5 Aug 2020 12:59:40 +0200
Subject: [PATCH] glow tts training and inference fixes

---
 TTS/tts/utils/io.py        |  2 +-
 TTS/tts/utils/synthesis.py |  4 +--
 TTS/tts/utils/visual.py    | 62 +++++++++++++++++++++++++++++---------
 3 files changed, 51 insertions(+), 17 deletions(-)

diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py
index bf5e13d8..0749bd14 100644
--- a/TTS/tts/utils/io.py
+++ b/TTS/tts/utils/io.py
@@ -18,7 +18,7 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
     if use_cuda:
         model.cuda()
     # set model stepsize
-    if 'r' in state.keys():
+    if hasattr(model.decoder, 'r'):
         model.decoder.set_r(state['r'])
     return model, state

diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 85eeec66..48083a2a 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -107,9 +107,9 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me

 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
-    decoder_output = decoder_output[0].data.cpu().numpy()
+    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
     alignment = alignments[0].cpu().data.numpy()
-    stop_tokens = stop_tokens[0].cpu().numpy()
+    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
     return postnet_output, decoder_output, alignment, stop_tokens


diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py
index 500d7707..033a5191 100644
--- a/TTS/tts/utils/visual.py
+++ b/TTS/tts/utils/visual.py
@@ -6,14 +6,20 @@ import matplotlib.pyplot as plt
 from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme


-def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
+def plot_alignment(alignment,
+                   info=None,
+                   fig_size=(16, 10),
+                   title=None,
+                   output_fig=False):
     if isinstance(alignment, torch.Tensor):
         alignment_ = alignment.detach().cpu().numpy().squeeze()
     else:
         alignment_ = alignment
     fig, ax = plt.subplots(figsize=fig_size)
-    im = ax.imshow(
-        alignment_.T, aspect='auto', origin='lower', interpolation='none')
+    im = ax.imshow(alignment_.T,
+                   aspect='auto',
+                   origin='lower',
+                   interpolation='none')
     fig.colorbar(im, ax=ax)
     xlabel = 'Decoder timestep'
     if info is not None:
@@ -29,7 +35,10 @@ def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_f
     return fig


-def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
+def plot_spectrogram(spectrogram,
+                     ap=None,
+                     fig_size=(16, 10),
+                     output_fig=False):
     if isinstance(spectrogram, torch.Tensor):
         spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
     else:
@@ -45,7 +54,17 @@ def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
     return fig


-def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24), output_fig=False):
+def visualize(alignment,
+              postnet_output,
+              text,
+              hop_length,
+              CONFIG,
+              stop_tokens=None,
+              decoder_output=None,
+              output_path=None,
+              figsize=(8, 24),
+              output_fig=False):
+
     if decoder_output is not None:
         num_plot = 4
     else:
@@ -60,18 +79,30 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,
     plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
     # compute phoneme representation and back
     if CONFIG.use_phonemes:
-        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
-        text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        seq = phoneme_to_sequence(
+            text, [CONFIG.text_cleaner],
+            CONFIG.phoneme_language,
+            CONFIG.enable_eos_bos_chars,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
+        text = sequence_to_phoneme(
+            seq,
+            tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
         print(text)
     plt.yticks(range(len(text)), list(text))
     plt.colorbar()
-    # plot stopnet predictions
-    plt.subplot(num_plot, 1, 2)
-    plt.plot(range(len(stop_tokens)), list(stop_tokens))
+
+    if stop_tokens is not None:
+        # plot stopnet predictions
+        plt.subplot(num_plot, 1, 2)
+        plt.plot(range(len(stop_tokens)), list(stop_tokens))
+
     # plot postnet spectrogram
     plt.subplot(num_plot, 1, 3)
-    librosa.display.specshow(postnet_output.T, sr=CONFIG.audio['sample_rate'],
-                             hop_length=hop_length, x_axis="time", y_axis="linear",
+    librosa.display.specshow(postnet_output.T,
+                             sr=CONFIG.audio['sample_rate'],
+                             hop_length=hop_length,
+                             x_axis="time",
+                             y_axis="linear",
                              fmin=CONFIG.audio['mel_fmin'],
                              fmax=CONFIG.audio['mel_fmax'])

@@ -82,8 +113,11 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG,

     if decoder_output is not None:
         plt.subplot(num_plot, 1, 4)
-        librosa.display.specshow(decoder_output.T, sr=CONFIG.audio['sample_rate'],
-                                 hop_length=hop_length, x_axis="time", y_axis="linear",
+        librosa.display.specshow(decoder_output.T,
+                                 sr=CONFIG.audio['sample_rate'],
+                                 hop_length=hop_length,
+                                 x_axis="time",
+                                 y_axis="linear",
                                  fmin=CONFIG.audio['mel_fmin'],
                                  fmax=CONFIG.audio['mel_fmax'])
         plt.xlabel("Time", fontsize=label_fontsize)
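
Notes on the changes (not part of the patch):

TTS/tts/utils/io.py: load_checkpoint now decides whether to restore the reduction factor by checking the model (hasattr(model.decoder, 'r')) rather than the checkpoint ('r' in state.keys()). Tacotron-style decoders expose r and set_r(); the GlowTTS decoder does not, so the old check would call set_r() on a model that has no such method whenever the checkpoint happened to carry an 'r' entry. A minimal sketch of the new guard, using dummy classes (TacotronLikeDecoder, GlowLikeDecoder and set_stepsize_from_checkpoint are illustrative names, not library code):

    class TacotronLikeDecoder:
        """Dummy stand-in: exposes a reduction factor, like Tacotron's decoder."""
        def __init__(self):
            self.r = 2

        def set_r(self, r):
            self.r = r


    class GlowLikeDecoder:
        """Dummy stand-in: no reduction factor, like the GlowTTS decoder."""


    class DummyModel:
        def __init__(self, decoder):
            self.decoder = decoder


    def set_stepsize_from_checkpoint(model, state):
        # the guard introduced by the patch: key off the model rather than the
        # checkpoint, so an 'r' entry in the state dict no longer breaks models
        # whose decoder has no reduction factor
        if hasattr(model.decoder, 'r'):
            model.decoder.set_r(state['r'])


    state = {'model': {}, 'r': 1}              # minimal fake checkpoint dict
    taco = DummyModel(TacotronLikeDecoder())
    glow = DummyModel(GlowLikeDecoder())
    set_stepsize_from_checkpoint(taco, state)  # applies r from the checkpoint
    set_stepsize_from_checkpoint(glow, state)  # skipped: decoder has no r
    assert taco.decoder.r == 1
    assert not hasattr(glow.decoder, 'r')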
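TTS/tts/utils/synthesis.py: parse_outputs_torch now passes decoder_output and stop_tokens through as None instead of indexing into them, because GlowTTS returns neither a separate decoder output nor stopnet predictions. A runnable sketch of the patched helper with dummy tensors; the tensor shapes are assumed for illustration only:

    import torch

    def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
        # mirrors the patched helper: None inputs are passed through untouched
        postnet_output = postnet_output[0].data.cpu().numpy()
        decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
        alignment = alignments[0].cpu().data.numpy()
        stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
        return postnet_output, decoder_output, alignment, stop_tokens

    # dummy batch-size-1 tensors, shapes purely illustrative
    postnet = torch.rand(1, 120, 80)     # [batch, frames, mel_bins]
    align = torch.rand(1, 120, 40)       # [batch, decoder_steps, encoder_steps]
    post, dec, ali, stop = parse_outputs_torch(postnet, None, align, None)
    assert dec is None and stop is None  # GlowTTS-style outputs stay None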
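TTS/tts/utils/visual.py: plot_alignment and plot_spectrogram only change argument formatting. In visualize(), stop_tokens moves behind the required arguments and becomes optional, and the stopnet panel is drawn only when stop tokens are given, so GlowTTS outputs can reuse the same helper; existing Tacotron call sites now have to pass stop_tokens=... by keyword. The sketch below condenses that control flow with plain matplotlib and dummy arrays; visualize_sketch, the array shapes and the output file name are illustrative, not the library API:

    import numpy as np
    import matplotlib
    matplotlib.use('Agg')                       # headless backend, as visual.py uses
    import matplotlib.pyplot as plt

    def visualize_sketch(alignment, postnet_output, stop_tokens=None, decoder_output=None):
        # condensed, librosa-free sketch of the control flow the patch gives
        # visualize(): optional stopnet and decoder panels
        num_plot = 4 if decoder_output is not None else 3
        plt.figure(figsize=(8, 24))

        plt.subplot(num_plot, 1, 1)             # attention alignment
        plt.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none')

        if stop_tokens is not None:             # stopnet panel only when available
            plt.subplot(num_plot, 1, 2)
            plt.plot(range(len(stop_tokens)), list(stop_tokens))

        plt.subplot(num_plot, 1, 3)             # postnet spectrogram
        plt.imshow(postnet_output.T, aspect='auto', origin='lower')

        if decoder_output is not None:          # decoder spectrogram only when available
            plt.subplot(num_plot, 1, 4)
            plt.imshow(decoder_output.T, aspect='auto', origin='lower')

        plt.savefig('visualize_sketch.png')     # hypothetical output path

    # GlowTTS-style call: both optional arguments omitted; arrays are dummies
    visualize_sketch(np.random.rand(120, 40),   # [decoder_steps, encoder_steps]
                     np.random.rand(120, 80))   # [frames, mel_bins]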