diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb new file mode 100644 index 00000000..a1867d13 --- /dev/null +++ b/notebooks/TestAttention.ipynb @@ -0,0 +1,198 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook is to test attention performance on hard sentences taken from DeepVoice paper." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os, sys\n", + "import torch \n", + "import time\n", + "import numpy as np\n", + "from matplotlib import pylab as plt\n", + "\n", + "%pylab inline\n", + "plt.rcParams[\"figure.figsize\"] = (16,5)\n", + "\n", + "import librosa\n", + "import librosa.display\n", + "\n", + "from TTS.layers import *\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.generic_utils import load_config, setup_model\n", + "from TTS.utils.text import text_to_sequence\n", + "from TTS.utils.synthesis import synthesis\n", + "from TTS.utils.visual import plot_alignment\n", + "from TTS.utils.measures import alignment_diagonal_score\n", + "\n", + "import IPython\n", + "from IPython.display import Audio\n", + "\n", + "os.environ['CUDA_VISIBLE_DEVICES']='2'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def tts(model, text, CONFIG, use_cuda, ap):\n", + " t_1 = time.time()\n", + " # run the model\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " # plotting\n", + " attn_score = alignment_diagonal_score(torch.FloatTensor(alignment).unsqueeze(0))\n", + " print(f\" > {text}\")\n", + " IPython.display.display(IPython.display.Audio(waveform, rate=ap.sample_rate))\n", + " fig = plot_alignment(alignment, fig_size=(8, 5))\n", + " IPython.display.display(fig)\n", + " #saving results\n", + " os.makedirs(OUT_FOLDER, exist_ok=True)\n", + " file_name = text[:200].replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", + " out_path = os.path.join(OUT_FOLDER, file_name)\n", + " ap.save_wav(waveform, out_path)\n", + " return attn_score" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set constants\n", + "ROOT_PATH = '/data/rw/pit/keep/ljspeech-December-11-2019_04+32PM-ca49ae8/'\n", + "MODEL_PATH = ROOT_PATH + '/checkpoint_410000.pth.tar'\n", + "CONFIG_PATH = ROOT_PATH + '/config.json'\n", + "OUT_FOLDER = './hard_sentences/'\n", + "CONFIG = load_config(CONFIG_PATH)\n", + "SENTENCES_PATH = 'sentences.txt'\n", + "use_cuda = True\n", + "\n", + "# Set some config fields manually for testing\n", + "# CONFIG.windowing = False\n", + "# CONFIG.prenet_dropout = False\n", + "# CONFIG.separate_stopnet = True\n", + "CONFIG.use_forward_attn = False\n", + "# CONFIG.forward_attn_mask = True\n", + "# CONFIG.stopnet = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LOAD TTS MODEL\n", + "from TTS.utils.text.symbols import symbols, phonemes\n", + "\n", + "# multi speaker \n", + "if CONFIG.use_speaker_embedding:\n", + " speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n", + " speakers_idx_to_id = {v: k for k, v in speakers.items()}\n", + "else:\n", + " speakers = []\n", + " speaker_id = None\n", + "\n", + "# load the model\n", + "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, len(speakers), CONFIG)\n", + "\n", + "# load the audio processor\n", + "ap = AudioProcessor(**CONFIG.audio) \n", + "\n", + "\n", + "# load model state\n", + "if use_cuda:\n", + " cp = torch.load(MODEL_PATH)\n", + "else:\n", + " cp = torch.load(MODEL_PATH, map_location=lambda storage, loc: storage)\n", + "\n", + "# load the model\n", + "model.load_state_dict(cp['model'])\n", + "if use_cuda:\n", + " model.cuda()\n", + "model.eval()\n", + "print(cp['step'])\n", + "print(cp['r'])\n", + "\n", + "# set model stepsize\n", + "if 'r' in cp:\n", + " model.decoder.set_r(cp['r'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.decoder.max_decoder_steps=3000\n", + "model.decoder.prenet.train()\n", + "attn_scores = []\n", + "with open(SENTENCES_PATH, 'r') as f:\n", + " for text in f:\n", + " try:\n", + " attn_score = tts(model, text, CONFIG, use_cuda, ap)\n", + " except ValueError:\n", + " attn_score = 0\n", + " attn_scores.append(attn_score)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.mean(attn_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.mean(attn_scores)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}