update test attention notebooks

This commit is contained in:
erogol 2020-03-17 13:24:30 +01:00
parent 0ee1dd54a3
commit acccac72f5
1 changed files with 8 additions and 35 deletions

View File

@ -40,19 +40,12 @@
"import IPython\n", "import IPython\n",
"from IPython.display import Audio\n", "from IPython.display import Audio\n",
"\n", "\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'" "os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
] "\n",
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tts(model, text, CONFIG, use_cuda, ap):\n", "def tts(model, text, CONFIG, use_cuda, ap):\n",
" t_1 = time.time()\n", " t_1 = time.time()\n",
" # run the model\n", " # run the model\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" # plotting\n", " # plotting\n",
@ -66,18 +59,11 @@
" file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", " file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
" out_path = os.path.join(OUT_FOLDER, file_name)\n", " out_path = os.path.join(OUT_FOLDER, file_name)\n",
" ap.save_wav(waveform, out_path)\n", " ap.save_wav(waveform, out_path)\n",
" return attn_score" " return attn_score\n",
] "\n",
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set constants\n", "# Set constants\n",
"ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n", "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n",
"MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n", "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = './hard_sentences/'\n", "OUT_FOLDER = './hard_sentences/'\n",
"CONFIG = load_config(CONFIG_PATH)\n", "CONFIG = load_config(CONFIG_PATH)\n",
@ -148,26 +134,13 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"model.decoder.max_decoder_steps=3000\n", "model.decoder.max_decoder_steps=3000\n",
"model.decoder.prenet.train()\n",
"attn_scores = []\n", "attn_scores = []\n",
"with open(SENTENCES_PATH, 'r') as f:\n", "with open(SENTENCES_PATH, 'r') as f:\n",
" for text in f:\n", " for text in f:\n",
" try:\n",
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
" except ValueError:\n",
" attn_score = 0\n",
" attn_scores.append(attn_score)" " attn_scores.append(attn_score)"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.mean(attn_scores)"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,