update test attention notebooks

This commit is contained in:
erogol 2020-03-17 13:24:30 +01:00
parent 25bcbe2887
commit 40cb4a53a6
1 changed files with 8 additions and 35 deletions

View File

@ -40,19 +40,12 @@
"import IPython\n",
"from IPython.display import Audio\n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
"\n",
"def tts(model, text, CONFIG, use_cuda, ap):\n",
" t_1 = time.time()\n",
" # run the model\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
" # plotting\n",
@ -66,18 +59,11 @@
" file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
" ap.save_wav(waveform, out_path)\n",
" return attn_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
" return attn_score\n",
"\n",
"# Set constants\n",
"ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n",
"MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n",
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n",
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
"OUT_FOLDER = './hard_sentences/'\n",
"CONFIG = load_config(CONFIG_PATH)\n",
@ -148,26 +134,13 @@
"outputs": [],
"source": [
"model.decoder.max_decoder_steps=3000\n",
"model.decoder.prenet.train()\n",
"attn_scores = []\n",
"with open(SENTENCES_PATH, 'r') as f:\n",
" for text in f:\n",
" try:\n",
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
" except ValueError:\n",
" attn_score = 0\n",
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
" attn_scores.append(attn_score)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.mean(attn_scores)"
]
},
{
"cell_type": "code",
"execution_count": null,