fix the benchmark notebook after GL fix

This commit is contained in:
erogol 2020-02-19 18:17:10 +01:00
parent 6977899d07
commit f8ebf6abcd
1 changed files with 11 additions and 10 deletions

View File

@ -65,6 +65,7 @@
"from TTS.utils.text import text_to_sequence\n", "from TTS.utils.text import text_to_sequence\n",
"from TTS.utils.synthesis import synthesis\n", "from TTS.utils.synthesis import synthesis\n",
"from TTS.utils.visual import visualize\n", "from TTS.utils.visual import visualize\n",
"from TTS.utils.text.symbols import symbols, phonemes\n",
"\n", "\n",
"import IPython\n", "import IPython\n",
"from IPython.display import Audio\n", "from IPython.display import Audio\n",
@ -81,13 +82,15 @@
"source": [ "source": [
"def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
" t_1 = time.time()\n", " t_1 = time.time()\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None, \n",
" truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n",
" use_griffin_lim=use_gl)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" # coorect the normalization differences b/w TTS and the Vocoder.\n", " # coorect the normalization differences b/w TTS and the Vocoder.\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
" mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n",
" if not use_gl:\n", " if not use_gl:\n",
" mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
" mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n",
" waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=8000, overlap=400)\n", " waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=8000, overlap=400)\n",
"\n", "\n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n", " print(\" > Run-time: {}\".format(time.time() - t_1))\n",
@ -108,7 +111,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Set constants\n", "# Set constants\n",
"ROOT_PATH = '/media/erogol/data_ssd/Models/libri_tts/5099/'\n", "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n", "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n", "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
@ -116,7 +119,7 @@
"VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/checkpoint_433000.pth.tar\"\n", "VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/checkpoint_433000.pth.tar\"\n",
"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n", "VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n",
"VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n", "VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n",
"use_cuda = False\n", "use_cuda = True\n",
"\n", "\n",
"# Set some config fields manually for testing\n", "# Set some config fields manually for testing\n",
"# CONFIG.windowing = False\n", "# CONFIG.windowing = False\n",
@ -127,7 +130,7 @@
"# CONFIG.stopnet = True\n", "# CONFIG.stopnet = True\n",
"\n", "\n",
"# Set the vocoder\n", "# Set the vocoder\n",
"use_gl = False # use GL if True\n", "use_gl = True # use GL if True\n",
"batched_wavernn = True # use batched wavernn inference if True" "batched_wavernn = True # use batched wavernn inference if True"
] ]
}, },
@ -138,8 +141,6 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# LOAD TTS MODEL\n", "# LOAD TTS MODEL\n",
"from utils.text.symbols import symbols, phonemes\n",
"\n",
"# multi speaker \n", "# multi speaker \n",
"if CONFIG.use_speaker_embedding:\n", "if CONFIG.use_speaker_embedding:\n",
" speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n", " speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n",
@ -181,7 +182,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# LOAD WAVERNN\n", "# LOAD WAVERNN - Make sure you downloaded the model and installed the module\n",
"if use_gl == False:\n", "if use_gl == False:\n",
" from WaveRNN.models.wavernn import Model\n", " from WaveRNN.models.wavernn import Model\n",
" from WaveRNN.utils.audio import AudioProcessor as AudioProcessorVocoder\n", " from WaveRNN.utils.audio import AudioProcessor as AudioProcessorVocoder\n",
@ -533,7 +534,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.3" "version": "3.7.4"
} }
}, },
"nbformat": 4, "nbformat": 4,