mirror of https://github.com/coqui-ai/TTS.git
update test attention notebooks
This commit is contained in:
parent
25bcbe2887
commit
40cb4a53a6
|
@ -40,19 +40,12 @@
|
|||
"import IPython\n",
|
||||
"from IPython.display import Audio\n",
|
||||
"\n",
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"os.environ['CUDA_VISIBLE_DEVICES']='1'\n",
|
||||
"\n",
|
||||
"def tts(model, text, CONFIG, use_cuda, ap):\n",
|
||||
" t_1 = time.time()\n",
|
||||
" # run the model\n",
|
||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
|
||||
" waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, True)\n",
|
||||
" if CONFIG.model == \"Tacotron\" and not use_gl:\n",
|
||||
" mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
|
||||
" # plotting\n",
|
||||
|
@ -66,18 +59,11 @@
|
|||
" file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n",
|
||||
" out_path = os.path.join(OUT_FOLDER, file_name)\n",
|
||||
" ap.save_wav(waveform, out_path)\n",
|
||||
" return attn_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" return attn_score\n",
|
||||
"\n",
|
||||
"# Set constants\n",
|
||||
"ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n",
|
||||
"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/'\n",
|
||||
"MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
|
||||
"CONFIG_PATH = ROOT_PATH + '/config.json'\n",
|
||||
"OUT_FOLDER = './hard_sentences/'\n",
|
||||
"CONFIG = load_config(CONFIG_PATH)\n",
|
||||
|
@ -148,26 +134,13 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"model.decoder.max_decoder_steps=3000\n",
|
||||
"model.decoder.prenet.train()\n",
|
||||
"attn_scores = []\n",
|
||||
"with open(SENTENCES_PATH, 'r') as f:\n",
|
||||
" for text in f:\n",
|
||||
" try:\n",
|
||||
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
|
||||
" except ValueError:\n",
|
||||
" attn_score = 0\n",
|
||||
" attn_score = tts(model, text, CONFIG, use_cuda, ap)\n",
|
||||
" attn_scores.append(attn_score)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.mean(attn_scores)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
|
Loading…
Reference in New Issue