From 824c09120b6e6bf4be99d03cc07e611ecbf140fe Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 12 Apr 2019 16:14:38 +0200
Subject: [PATCH] Add notebook to extract spectrograms from trained TTS model

---
 notebooks/ExtractTTSpectrogram.ipynb | 274 +++++++++++++++++++++++++++
 1 file changed, 274 insertions(+)
 create mode 100644 notebooks/ExtractTTSpectrogram.ipynb

diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb
new file mode 100644
index 00000000..42b0ba48
--- /dev/null
+++ b/notebooks/ExtractTTSpectrogram.ipynb
@@ -0,0 +1,274 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook extracts mel-spectrograms from a trained TTS model for WaveRNN vocoder training. The model is run with teacher forcing, so the generated spectrograms are time-aligned with the ground-truth audio."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# path to the directory that contains the TTS repository\n",
+    "TTS_PATH = \"/home/erogol/projects/\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "sys.path.append(TTS_PATH)  # make the TTS package importable\n",
+    "import importlib\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "from tqdm import tqdm\n",
+    "from torch.utils.data import DataLoader\n",
+    "from TTS.datasets.TTSDataset import MyDataset\n",
+    "from TTS.datasets.preprocess import ljspeech\n",
+    "from TTS.utils.audio import AudioProcessor\n",
+    "from TTS.utils.visual import plot_spectrogram\n",
+    "from TTS.utils.generic_utils import load_config\n",
+    "%matplotlib inline\n",
+    "\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def set_filename(wav_path, out_path):\n",
+    "    \"\"\"Return the base file name and output paths for the quantized wav, mel and Griffin-Lim wav.\"\"\"\n",
+    "    wav_file = os.path.basename(wav_path)\n",
+    "    file_name = wav_file.split('.')[0]\n",
+    "    os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
+    "    os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
+    "    os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
+    "    wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
+    "    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
+    "    wav_gl_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
+    "    return file_name, wavq_path, mel_path, wav_gl_path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "OUT_PATH = \"/home/erogol/Data/LJSpeech-1.1/wavernn_4152/\"\n",
+    "DATA_PATH = \"/home/erogol/Data/LJSpeech-1.1/\"\n",
+    "METADATA_FILE = \"metadata_train.csv\"\n",
+    "CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/ljspeech_models/4258_nancy/config.json\"\n",
+    "MODEL_FILE = \"/home/erogol/checkpoint_92000.pth.tar\"\n",
+    "DRY_RUN = True  # if True, only compute the loss and visuals; do not write output files\n",
+    "BATCH_SIZE = 16\n",
+    "\n",
+    "use_cuda = torch.cuda.is_available()\n",
+    "\n",
+    "C = load_config(CONFIG_PATH)\n",
+    "ap = AudioProcessor(bits=9, **C.audio)  # 9-bit quantization for the WaveRNN targets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap, ljspeech,\n",
+    "                    use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path)\n",
+    "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4,\n",
+    "                    collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
+   ]
+  },
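+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional: sanity-check a single batch from the loader before the full run. This is a minimal sketch; the field order is assumed to match the unpacking used in the extraction loop below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Peek at one batch to confirm tensor shapes (field order assumed from the loop below).\n",
+    "batch = next(iter(loader))\n",
+    "text_input, text_lengths, linear_input, mel_input, mel_lengths, stop_targets, item_idx = batch\n",
+    "print(\"text_input:\", tuple(text_input.shape))  # (batch, max_text_len)\n",
+    "print(\"mel_input:\", tuple(mel_input.shape))    # (batch, max_mel_len, num_mels)\n",
+    "print(\"mel_lengths:\", mel_lengths[:4])"
+   ]
+  },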
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from TTS.utils.text.symbols import symbols, phonemes\n",
+    "from TTS.utils.generic_utils import sequence_mask\n",
+    "from TTS.layers.losses import L1LossMasked\n",
+    "\n",
+    "# import the model class named in the config and restore the checkpoint\n",
+    "MyModel = importlib.import_module('TTS.models.' + C.model.lower())\n",
+    "MyModel = getattr(MyModel, C.model)\n",
+    "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
+    "model = MyModel(num_chars, C.r, attn_win=False)\n",
+    "checkpoint = torch.load(MODEL_FILE, map_location='cpu')\n",
+    "model.load_state_dict(checkpoint['model'])\n",
+    "print(\"Restored checkpoint from step:\", checkpoint['step'])\n",
+    "model.eval()\n",
+    "if use_cuda:\n",
+    "    model = model.cuda()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Generate model outputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pickle\n",
+    "\n",
+    "file_idxs = []\n",
+    "losses = []\n",
+    "postnet_losses = []\n",
+    "criterion = L1LossMasked()\n",
+    "for data in tqdm(loader):\n",
+    "    # set up input data\n",
+    "    text_input = data[0]\n",
+    "    text_lengths = data[1]\n",
+    "    linear_input = data[2]\n",
+    "    mel_input = data[3]\n",
+    "    mel_lengths = data[4]\n",
+    "    stop_targets = data[5]\n",
+    "    item_idx = data[6]\n",
+    "\n",
+    "    # dispatch data to GPU\n",
+    "    if use_cuda:\n",
+    "        text_input = text_input.cuda()\n",
+    "        text_lengths = text_lengths.cuda()\n",
+    "        mel_input = mel_input.cuda()\n",
+    "        mel_lengths = mel_lengths.cuda()\n",
+    "        # linear_input = linear_input.cuda()\n",
+    "        stop_targets = stop_targets.cuda()\n",
+    "\n",
+    "    mask = sequence_mask(text_lengths)\n",
+    "    # inference only; no gradients needed\n",
+    "    with torch.no_grad():\n",
+    "        mel_outputs, mel_postnet_outputs, alignments, stop_tokens = model.forward(\n",
+    "            text_input, text_lengths, mel_input, mask)\n",
+    "        loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
+    "        loss_postnet = criterion(mel_postnet_outputs, mel_input, mel_lengths)\n",
+    "    losses.append(loss.item())\n",
+    "    postnet_losses.append(loss_postnet.item())\n",
+    "\n",
+    "    if not DRY_RUN:\n",
+    "        for idx in range(text_input.shape[0]):\n",
+    "            wav_file_path = item_idx[idx]\n",
+    "            wav = ap.load_wav(wav_file_path)\n",
+    "            file_name, wavq_path, mel_path, wav_gl_path = set_filename(wav_file_path, OUT_PATH)\n",
+    "            file_idxs.append(file_name)\n",
+    "\n",
+    "            # quantize and save wav\n",
+    "            wavq = ap.quantize(wav)\n",
+    "            np.save(wavq_path, wavq)\n",
+    "\n",
+    "            # save TTS mel as (num_mels, T), trimmed to the true length\n",
+    "            mel = mel_postnet_outputs[idx]\n",
+    "            mel = mel.data.cpu().numpy()\n",
+    "            mel_length = mel_lengths[idx]\n",
+    "            mel = mel[:mel_length, :].T\n",
+    "            np.save(mel_path, mel)\n",
+    "\n",
+    "            # optionally save a Griffin-Lim reconstruction for listening tests\n",
+    "            # wav_gen = ap.inv_mel_spectrogram(mel.T)  # mel to wav\n",
+    "            # wav_gen = ap.quantize(wav_gen)\n",
+    "            # np.save(wav_gl_path, wav_gen)\n",
+    "\n",
+    "if not DRY_RUN:\n",
+    "    with open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\") as f:\n",
+    "        pickle.dump(file_idxs, f)\n",
+    "\n",
+    "print(\"decoder mel loss:\", np.mean(losses))\n",
+    "print(\"postnet mel loss:\", np.mean(postnet_losses))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Check model performance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "idx = 1  # pick one example from the last batch\n",
+    "mel_example = mel_postnet_outputs[idx].data.cpu().numpy()\n",
+    "plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
+    "print(mel_example[:mel_lengths[idx], :].shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# compare against the ground-truth mel computed from the source wav\n",
+    "wav = ap.load_wav(item_idx[idx])\n",
+    "melt = ap.melspectrogram(wav)\n",
+    "print(melt.shape)\n",
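+    "# ap.melspectrogram() returns (num_mels, T); plot_spectrogram expects (T, num_mels),\n",
+    "# hence the transpose below.\n",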
"plot_spectrogram(melt.T, ap);" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pylab as plt\n", + "mel_diff = mel_outputs[idx] - mel_postnet_outputs[idx]\n", + "plt.figure(figsize=(16, 10))\n", + "plt.imshow(abs(mel_diff.detach().cpu().numpy()[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.colorbar()\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pylab as plt\n", + "# mel = mel_poutputs[idx].detach().cpu().numpy()\n", + "mel = mel_postnet_outputs[idx].detach().cpu().numpy()\n", + "mel_diff2 = melt.T - mel[:melt.shape[1]]\n", + "plt.figure(figsize=(16, 10))\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.colorbar()\n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}