fix the benchmark notebook after GL fix

This commit is contained in:
erogol 2020-02-19 18:17:10 +01:00
parent 6977899d07
commit f8ebf6abcd
1 changed file with 11 additions and 10 deletions


@@ -65,6 +65,7 @@
"from TTS.utils.text import text_to_sequence\n",
"from TTS.utils.synthesis import synthesis\n",
"from TTS.utils.visual import visualize\n",
"from TTS.utils.text.symbols import symbols, phonemes\n",
"\n",
"import IPython\n",
"from IPython.display import Audio\n",
@@ -81,13 +82,15 @@
"source": [
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
" t_1 = time.time()\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None, \n",
" truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n",
" use_griffin_lim=use_gl)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" # coorect the normalization differences b/w TTS and the Vocoder.\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" if not use_gl:\n",
" mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
" mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n",
" if not use_gl:\n",
" waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=8000, overlap=400)\n",
"\n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
@@ -108,7 +111,7 @@
"outputs": [],
"source": [
"# Set constants\n",
"ROOT_PATH = '/media/erogol/data_ssd/Models/libri_tts/5099/'\n",
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
@@ -116,7 +119,7 @@
"VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/checkpoint_433000.pth.tar\"\n",
"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n",
"VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n",
"use_cuda = False\n",
"use_cuda = True\n",
"\n",
"# Set some config fields manually for testing\n",
"# CONFIG.windowing = False\n",
@@ -127,7 +130,7 @@
"# CONFIG.stopnet = True\n",
"\n",
"# Set the vocoder\n",
"use_gl = False # use GL if True\n",
"use_gl = True # use GL if True\n",
"batched_wavernn = True # use batched wavernn inference if True"
]
},
@@ -138,8 +141,6 @@
"outputs": [],
"source": [
"# LOAD TTS MODEL\n",
"from utils.text.symbols import symbols, phonemes\n",
"\n",
"# multi speaker \n",
"if CONFIG.use_speaker_embedding:\n",
" speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n",
@@ -181,7 +182,7 @@
"metadata": {},
"outputs": [],
"source": [
"# LOAD WAVERNN\n",
"# LOAD WAVERNN - Make sure you downloaded the model and installed the module\n",
"if use_gl == False:\n",
" from WaveRNN.models.wavernn import Model\n",
" from WaveRNN.utils.audio import AudioProcessor as AudioProcessorVocoder\n",
@@ -533,7 +534,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.4"
}
},
"nbformat": 4,