mirror of https://github.com/coqui-ai/TTS.git
update test attention notebooks
This commit is contained in:
parent
0ee1dd54a3
commit
acccac72f5
|
@ -40,19 +40,12 @@
|
||||||
"import IPython\n",
|
"import IPython\n",
|
||||||
"from IPython.display import Audio\n",
|
"from IPython.display import Audio\n",
|
||||||
"\n",
|
"\n",
|
||||||
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
|
"os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
|
||||||
]
|
"\n",
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def tts(model, text, CONFIG, use_cuda, ap):\n",
|
"def tts(model, text, CONFIG, use_cuda, ap):\n",
|
||||||
" t_1 = time.time()\n",
|
" t_1 = time.time()\n",
|
||||||
" # run the model\n",
|
" # run the model\n",
|
||||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
|
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
|
||||||
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
|
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
|
||||||
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
|
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
|
||||||
" # plotting\n",
|
" # plotting\n",
|
||||||
|
@ -66,18 +59,11 @@
|
||||||
" file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
|
" file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
|
||||||
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
|
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
|
||||||
" ap.save_wav(waveform, out_path)\n",
|
" ap.save_wav(waveform, out_path)\n",
|
||||||
" return attn_score"
|
" return attn_score\n",
|
||||||
]
|
"\n",
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Set constants\n",
|
"# Set constants\n",
|
||||||
"ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n",
|
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n",
|
||||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n",
|
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
|
||||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||||
"OUT_FOLDER = './hard_sentences/'\n",
|
"OUT_FOLDER = './hard_sentences/'\n",
|
||||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||||
|
@ -148,26 +134,13 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model.decoder.max_decoder_steps=3000\n",
|
"model.decoder.max_decoder_steps=3000\n",
|
||||||
"model.decoder.prenet.train()\n",
|
|
||||||
"attn_scores = []\n",
|
"attn_scores = []\n",
|
||||||
"with open(SENTENCES_PATH, 'r') as f:\n",
|
"with open(SENTENCES_PATH, 'r') as f:\n",
|
||||||
" for text in f:\n",
|
" for text in f:\n",
|
||||||
" try:\n",
|
|
||||||
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
|
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
|
||||||
" except ValueError:\n",
|
|
||||||
" attn_score = 0\n",
|
|
||||||
" attn_scores.append(attn_score)"
|
" attn_scores.append(attn_score)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"np.mean(attn_scores)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|
Loading…
Reference in New Issue