mirror of https://github.com/coqui-ai/TTS.git
Update notebooks
This commit is contained in:
parent
3c2d500f53
commit
7d3d825904
|
@ -4,7 +4,27 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook runs the given model over all the benchmark sentences and shows their alignment, stop-token prediction and spectrogram plot in order."
|
||||
"This is to test TTS models with benchmark sentences for speech synthesis.\n",
|
||||
"\n",
|
||||
"Before running this script please DON'T FORGET: \n",
|
||||
"- to set file paths.\n",
|
||||
"- to download related model files from TTS and WaveRNN.\n",
|
||||
"- to check out the right commit versions (given next to the model) of TTS and WaveRNN.\n",
|
||||
"- to set the right paths in the cell below.\n",
|
||||
"\n",
|
||||
"Repositories:\n",
|
||||
"- TTS: https://github.com/mozilla/TTS\n",
|
||||
"- WaveRNN: https://github.com/erogol/WaveRNN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TTS_PATH = \"/home/erogol/projects/\"\n",
|
||||
"WAVERNN_PATH =\"/home/erogol/projects/\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -28,7 +48,10 @@
|
|||
"\n",
|
||||
"%pylab inline\n",
|
||||
"rcParams[\"figure.figsize\"] = (16,5)\n",
|
||||
"# sys.path.append('/home/erogol/projects/') # set this if TTS is not installed globally\n",
|
||||
"\n",
|
||||
"# add libraries into environment\n",
|
||||
"sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n",
|
||||
"sys.path.append(WAVERNN_PATH) # set this if WaveRNN is not installed globally\n",
|
||||
"\n",
|
||||
"import librosa\n",
|
||||
"import librosa.display\n",
|
||||
|
@ -54,7 +77,7 @@
|
|||
"source": [
|
||||
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
|
||||
" t_1 = time.time()\n",
|
||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap)\n",
|
||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, False, CONFIG.enable_eos_bos_chars)\n",
|
||||
" if not use_gl:\n",
|
||||
" waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=11000, overlap=550)\n",
|
||||
"\n",
|
||||
|
@ -76,8 +99,8 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"# Set constants\n",
|
||||
"ROOT_PATH = '/media/erogol/data_ssd/Data/models/ljspeech_models/4241_best_t2_model/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/best_model_edited.pth.tar'\n",
|
||||
"ROOT_PATH = '/media/erogol/data_ssd/Data/models/ljspeech_models/ljspeech-April-08-2019_07+32PM-8a47b46/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_260000.pth.tar'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
|
@ -85,9 +108,9 @@
|
|||
"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n",
|
||||
"VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n",
|
||||
"use_cuda = True\n",
|
||||
"windowing = True\n",
|
||||
"CONFIG.windowing = True\n",
|
||||
"use_gl = False # use GL if True\n",
|
||||
"batched_wavernn = False # use batched wavernn inference if True"
|
||||
"batched_wavernn = True # use batched wavernn inference if True"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -104,8 +127,6 @@
|
|||
"model = setup_model(num_chars, CONFIG)\n",
|
||||
"\n",
|
||||
"# load the audio processor\n",
|
||||
"# CONFIG.audio[\"power\"] = 1.3\n",
|
||||
"CONFIG.audio[\"preemphasis\"] = 0.97\n",
|
||||
"ap = AudioProcessor(**CONFIG.audio) \n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
@ -205,7 +226,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \"I'm sorry Dave. I'm afraid I can't do that .\"\n",
|
||||
"sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
|
@ -244,7 +265,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = \" Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"sentence = \"Generative adversarial network or variational auto-encoder.\"\n",
|
||||
"align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)"
|
||||
]
|
||||
},
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,52 +0,0 @@
|
|||
import io
|
||||
import librosa
|
||||
import torch
|
||||
import numpy as np
|
||||
from TTS.utils.text import text_to_sequence
|
||||
from matplotlib import pylab as plt
|
||||
|
||||
hop_length = 250
|
||||
|
||||
|
||||
def create_speech(m, s, CONFIG, use_cuda, ap):
    """Synthesize speech for the text ``s`` with model ``m``.

    The text is converted to a sequence with ``text_to_sequence`` using the
    cleaner named by ``CONFIG.text_cleaner``, wrapped into a batch of one,
    and fed through ``m.forward``. The first item of the model's linear
    output is inverted to a waveform by ``ap`` and cut at the index
    reported by ``ap.find_endpoint``.

    Args:
        m: model object exposing ``forward``; it must return four values
            (mel output, linear output, alignments, stop tokens).
        s: input text.
        CONFIG: configuration object; only ``text_cleaner`` is read here.
        use_cuda: when truthy, the input tensor is moved with ``.cuda()``.
        ap: audio processor providing ``_denormalize``, ``inv_spectrogram``,
            ``find_endpoint`` and ``save_wav``.

    Returns:
        Tuple ``(wav, alignment, spec, stop_tokens)``: the trimmed waveform,
        the first alignment as a NumPy array, the denormalized linear
        output, and the stop tokens exactly as returned by the model.
    """
    cleaner_names = [CONFIG.text_cleaner]
    token_ids = np.array(text_to_sequence(s, cleaner_names))

    # Prepend a batch axis of size one before handing the ids to the model.
    model_input = torch.from_numpy(token_ids).unsqueeze(0)
    if use_cuda:
        model_input = model_input.cuda()

    # The mel output (first element) is not used by this helper.
    _, linear_batch, alignment_batch, stop_tokens = m.forward(
        model_input.long())

    # Only the first item of the batch is converted to NumPy.
    linear_np = linear_batch[0].data.cpu().numpy()
    alignment = alignment_batch[0].cpu().data.numpy()

    spec = ap._denormalize(linear_np)

    full_wav = ap.inv_spectrogram(linear_np.T)
    end_index = ap.find_endpoint(full_wav)
    wav = full_wav[:end_index]

    # NOTE(review): the waveform is written to an in-memory buffer that is
    # never read or returned; kept so the call sequence on ``ap`` is unchanged.
    ap.save_wav(wav, io.BytesIO())

    return wav, alignment, spec, stop_tokens
|
||||
|
||||
|
||||
def visualize(alignment, spectrogram, stop_tokens, CONFIG):
    """Plot alignment, stop tokens and spectrogram in one three-row figure.

    Args:
        alignment: 2-D array; its transpose is shown with ``imshow``.
        spectrogram: 2-D array; its transpose is passed to
            ``librosa.display.specshow``.
        stop_tokens: tensor of stop-token values; it is squeezed, detached,
            moved to CPU and converted to NumPy before plotting.
        CONFIG: configuration object; only ``sample_rate`` is read here.

    The spectrogram's time axis uses the module-level ``hop_length``.
    Nothing is returned; the figure is drawn through ``pyplot``'s current
    figure state.
    """
    axis_label_size = 16
    plt.figure(figsize=(16, 24))

    # Row 1: alignment.
    plt.subplot(3, 1, 1)
    plt.imshow(
        alignment.T,
        aspect="auto",
        origin="lower",
        interpolation=None)
    plt.xlabel("Decoder timestamp", fontsize=axis_label_size)
    plt.ylabel("Encoder timestamp", fontsize=axis_label_size)
    plt.colorbar()

    # Row 2: stop-token values against their index.
    stop_values = stop_tokens.squeeze().detach().to('cpu').numpy()
    step_indices = range(len(stop_values))
    plt.subplot(3, 1, 2)
    plt.plot(step_indices, list(stop_values))

    # Row 3: spectrogram on a linear frequency axis.
    plt.subplot(3, 1, 3)
    specshow_options = dict(
        sr=CONFIG.sample_rate,
        hop_length=hop_length,
        x_axis="time",
        y_axis="linear")
    librosa.display.specshow(spectrogram.T, **specshow_options)
    plt.xlabel("Time", fontsize=axis_label_size)
    plt.ylabel("Hz", fontsize=axis_label_size)
    plt.tight_layout()
    plt.colorbar()
|
Loading…
Reference in New Issue