Update notebooks

This commit is contained in:
Eren Gölge 2021-10-21 16:20:14 +00:00
parent 5e0d0539c5
commit 016803beee
5 changed files with 243 additions and 69 deletions

View File

@ -2,14 +2,16 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used in Vocoder training." "This is a notebook to generate mel-spectrograms from a TTS model to be used in Vocoder training."
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload 2\n", "%autoreload 2\n",
@ -20,7 +22,7 @@
"import numpy as np\n", "import numpy as np\n",
"from tqdm import tqdm as tqdm\n", "from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n", "from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import TTSDataset\n", "from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n", "from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n", "from TTS.config import load_config\n",
@ -33,13 +35,13 @@
"\n", "\n",
"import os\n", "import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'" "os.environ['CUDA_VISIBLE_DEVICES']='2'"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def set_filename(wav_path, out_path):\n", "def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n", " wav_file = os.path.basename(wav_path)\n",
@ -51,13 +53,13 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path" " return file_name, wavq_path, mel_path, wav_path"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
@ -77,13 +79,13 @@
"C = load_config(CONFIG_PATH)\n", "C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)" "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"print(C['r'])\n", "print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n", "# if the vocabulary was passed, replace the default\n",
@ -95,13 +97,13 @@
"# TODO: multiple speaker\n", "# TODO: multiple speaker\n",
"model = setup_model(C)\n", "model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)" "model.load_checkpoint(C, MODEL_FILE, eval=True)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n",
@ -120,20 +122,20 @@
"loader = DataLoader(\n", "loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n", " dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n" ")\n"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### Generate model outputs " "### Generate model outputs "
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import pickle\n", "import pickle\n",
"\n", "\n",
@ -212,42 +214,42 @@
"\n", "\n",
" print(np.mean(losses))\n", " print(np.mean(losses))\n",
" print(np.mean(postnet_losses))" " print(np.mean(postnet_losses))"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# for pwgan\n", "# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n", " for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### Sanity Check" "### Sanity Check"
], ]
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"idx = 1\n", "idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import soundfile as sf\n", "import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n", "wav, sr = sf.read(item_idx[idx])\n",
@ -255,46 +257,46 @@
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n", "mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)" "print(mel_truth.shape)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot postnet output\n", "# plot postnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n", "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)" "plot_spectrogram(mel_postnet, ap)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot decoder output\n", "# plot decoder output\n",
"print(mel_decoder.shape)\n", "print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)" "plot_spectrogram(mel_decoder, ap)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot GT spectrogram\n", "# plot GT spectrogram\n",
"print(mel_truth.shape)\n", "print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)" "plot_spectrogram(mel_truth.T, ap)"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# postnet, decoder diff\n", "# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -303,13 +305,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -318,13 +320,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -334,22 +336,23 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
], ]
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"source": [], "metadata": {},
"outputs": [], "outputs": [],
"metadata": {} "source": []
} }
], ],
"metadata": { "metadata": {
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
},
"kernelspec": { "kernelspec": {
"name": "python3", "display_name": "Python 3.9.7 64-bit ('base': conda)",
"display_name": "Python 3.9.7 64-bit ('base': conda)" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -362,9 +365,6 @@
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.7" "version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -19,19 +19,16 @@
"source": [ "source": [
"import os\n", "import os\n",
"import glob\n", "import glob\n",
"import random\n",
"import numpy as np\n", "import numpy as np\n",
"import torch\n",
"import umap\n", "import umap\n",
"\n", "\n",
"from TTS.speaker_encoder.model import SpeakerEncoder\n",
"from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.generic_utils import load_config\n", "from TTS.config import load_config\n",
"\n", "\n",
"from bokeh.io import output_notebook, show\n", "from bokeh.io import output_notebook, show\n",
"from bokeh.plotting import figure\n", "from bokeh.plotting import figure\n",
"from bokeh.models import HoverTool, ColumnDataSource, BoxZoomTool, ResetTool, OpenURL, TapTool\n", "from bokeh.models import HoverTool, ColumnDataSource, BoxZoomTool, ResetTool, OpenURL, TapTool\n",
"from bokeh.transform import factor_cmap, factor_mark\n", "from bokeh.transform import factor_cmap\n",
"from bokeh.palettes import Category10" "from bokeh.palettes import Category10"
] ]
}, },

View File

@ -22,7 +22,6 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n", "sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n",
"import glob\n",
"import librosa\n", "import librosa\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",

View File

@ -21,10 +21,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os, sys\n", "import os\n",
"import glob\n", "import glob\n",
"import subprocess\n", "import subprocess\n",
"import tempfile\n",
"import IPython\n", "import IPython\n",
"import soundfile as sf\n", "import soundfile as sf\n",
"import numpy as np\n", "import numpy as np\n",

File diff suppressed because one or more lines are too long