{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a notebook used to generate the speaker embeddings with the AngleProto speaker encoder model for multi-speaker training.\n",
    "\n",
    "Before running this script please DON'T FORGET: \n",
    "- to set file paths.\n",
    "- to download related model files from TTS.\n",
    "- to download or clone the related repos, linked below.\n",
    "- to set up the repositories: ```python setup.py install```\n",
    "- to check out the right commit versions (given next to the model) of TTS.\n",
    "- to set the right paths in the cell below.\n",
    "\n",
    "Repository:\n",
    "- TTS: https://github.com/mozilla/TTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import importlib\n",
    "import os\n",
    "import random\n",
    "\n",
    "import librosa\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# you may need to change this depending on your system\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "\n",
    "\n",
    "from TTS.speaker_encoder.model import SpeakerEncoder\n",
    "from TTS.tts.utils.speakers import save_speaker_mapping, load_speaker_mapping\n",
    "from TTS.utils.audio import AudioProcessor\n",
    "from TTS.utils.io import load_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should also adjust all the path constants below so that they point to the relevant locations on your machine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory of the trained speaker encoder run; the exported speakers.json is also written here.\n",
    "MODEL_RUN_PATH = \"../../Mozilla-TTS/checkpoints/libritts_100+360-angleproto-June-06-2020_04+12PM-9c04d1f/\"\n",
    "# os.path.join gives the right result whether or not MODEL_RUN_PATH ends with a path separator.\n",
    "MODEL_PATH = os.path.join(MODEL_RUN_PATH, \"best_model.pth.tar\")\n",
    "CONFIG_PATH = os.path.join(MODEL_RUN_PATH, \"config.json\")\n",
    "\n",
    "\n",
    "# The three lists below are read index by index, so entry i of each list describes the same dataset.\n",
    "DATASETS_NAME = ['vctk'] # list the datasets; each name is lower-cased and looked up in TTS.tts.datasets.preprocess\n",
    "DATASETS_PATH = ['../../../datasets/VCTK/']\n",
    "DATASETS_METAFILE = ['']\n",
    "\n",
    "# Use the GPU when one is visible; set this to False to force CPU inference.\n",
    "USE_CUDA = torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess dataset: gather the metadata entries of every configured dataset into one list.\n",
    "preprocess_module = importlib.import_module('TTS.tts.datasets.preprocess')\n",
    "\n",
    "meta_data = []\n",
    "for idx, dataset_name in enumerate(DATASETS_NAME):\n",
    "    # the loader function is the module attribute named after the lower-cased dataset name\n",
    "    load_items = getattr(preprocess_module, dataset_name.lower())\n",
    "    meta_data.extend(load_items(DATASETS_PATH[idx], DATASETS_METAFILE[idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the audio processor and the speaker encoder from the run's config file.\n",
    "c = load_config(CONFIG_PATH)\n",
    "ap = AudioProcessor(**c['audio'])\n",
    "\n",
    "model = SpeakerEncoder(**c.model)\n",
    "# Load the checkpoint onto the CPU first so it can also be read on a machine\n",
    "# without a GPU; the model is moved to the GPU just below when USE_CUDA is set.\n",
    "checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "model.eval()\n",
    "if USE_CUDA:\n",
    "    model.cuda()\n",
    "\n",
    "# Maps wav file base name -> [embedding as a flat numpy array, speaker_id].\n",
    "# NOTE: keys are base names only, so two files with the same name in different\n",
    "# directories (or datasets) overwrite each other.\n",
    "embeddings_dict = {}\n",
    "\n",
    "# Inference only: no_grad avoids recording an autograd graph for every utterance.\n",
    "with torch.no_grad():\n",
    "    for _unused, wav_file, speaker_id in tqdm(meta_data):\n",
    "        wav_file_name = os.path.basename(wav_file)\n",
    "        # transpose the mel spectrogram and add a leading batch axis of size 1\n",
    "        mel_spec = ap.melspectrogram(ap.load_wav(wav_file)).T\n",
    "        mel_spec = torch.FloatTensor(mel_spec[None, :, :])\n",
    "        if USE_CUDA:\n",
    "            mel_spec = mel_spec.cuda()\n",
    "        embedd = model.compute_embedding(mel_spec).cpu().detach().numpy().reshape(-1)\n",
    "        embeddings_dict[wav_file_name] = [embedd, speaker_id]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create and export speakers.json: one entry per wav file name, holding the\n",
    "# speaker name and the embedding flattened into a plain Python list\n",
    "speaker_mapping = {\n",
    "    wav_name: {'name': speaker_name, 'embedding': embedding.reshape(-1).tolist()}\n",
    "    for wav_name, (embedding, speaker_name) in embeddings_dict.items()\n",
    "}\n",
    "save_speaker_mapping(MODEL_RUN_PATH, speaker_mapping)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test load integrity: reload the exported mapping and compare it with the in-memory one\n",
    "speaker_mapping_load = load_speaker_mapping(MODEL_RUN_PATH)\n",
    "assert speaker_mapping == speaker_mapping_load, \"reloaded speakers.json differs from the exported mapping\"\n",
    "\n",
    "# The mapping is keyed by wav file, so its length is the number of embeddings;\n",
    "# the number of speakers is the number of distinct 'name' values.\n",
    "num_speakers = len({entry['name'] for entry in speaker_mapping.values()})\n",
    "print(\"The file speakers.json has been exported to \",MODEL_RUN_PATH, ' with ', len(speaker_mapping), ' embeddings from ', num_speakers, ' speakers')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}