From 40cb4a53a67fac2b13256b27acfa00ad2baba3f8 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 17 Mar 2020 13:24:30 +0100 Subject: [PATCH] update test attention notebooks --- notebooks/TestAttention.ipynb | 43 +++++++---------------------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index 9d3e5e75..b350b070 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -40,19 +40,12 @@ "import IPython\n", "from IPython.display import Audio\n", "\n", - "os.environ['CUDA_VISIBLE_DEVICES']='2'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", + "\n", "def tts(model, text, CONFIG, use_cuda, ap):\n", " t_1 = time.time()\n", " # run the model\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " # plotting\n", @@ -66,18 +59,11 @@ " file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", " out_path = os.path.join(OUT_FOLDER, file_name)\n", " ap.save_wav(waveform, out_path)\n", - " return attn_score" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " return attn_score\n", + "\n", "# Set constants\n", - "ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n", - "MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n", + "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n", + "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n", "OUT_FOLDER = './hard_sentences/'\n", "CONFIG = load_config(CONFIG_PATH)\n", @@ -148,26 +134,13 @@ "outputs": [], "source": [ "model.decoder.max_decoder_steps=3000\n", - "model.decoder.prenet.train()\n", "attn_scores = []\n", "with open(SENTENCES_PATH, 'r') as f:\n", " for text in f:\n", - " try:\n", - " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", - " except ValueError:\n", - " attn_score = 0\n", + " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", " attn_scores.append(attn_score)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np.mean(attn_scores)" - ] - }, { "cell_type": "code", "execution_count": null,