From 1de010acd4881b18a881b905264964e05eb98d66 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Thu, 9 Sep 2021 10:59:23 +0000
Subject: [PATCH] Update notebook compat

---
 notebooks/ExtractTTSpectrogram.ipynb | 170 ++++++++++++++-------
 1 file changed, 90 insertions(+), 80 deletions(-)

diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb
index 4e42a3bb..3597ebe3 100644
--- a/notebooks/ExtractTTSpectrogram.ipynb
+++ b/notebooks/ExtractTTSpectrogram.ipynb
@@ -2,16 +2,14 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "metadata": {},
    "source": [
-    "This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training."
-   ]
+    "This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
+   ],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -25,22 +23,23 @@
     "from TTS.tts.datasets.TTSDataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
     "from TTS.utils.audio import AudioProcessor\n",
-    "from TTS.utils.io import load_config\n",
+    "from TTS.config import load_config\n",
     "from TTS.tts.utils.visual import plot_spectrogram\n",
-    "from TTS.tts.utils.generic_utils import setup_model, sequence_mask\n",
+    "from TTS.tts.utils.helpers import sequence_mask\n",
+    "from TTS.tts.models import setup_model\n",
     "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
     "import os\n",
-    "os.environ['CUDA_VISIBLE_DEVICES']='0'"
-   ]
+    "os.environ['CUDA_VISIBLE_DEVICES']='2'"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "def set_filename(wav_path, out_path):\n",
     "    wav_file = os.path.basename(wav_path)\n",
@@ -52,20 +51,20 @@
     "    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
     "    wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
     "    return file_name, wavq_path, mel_path, wav_path"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
-    "OUT_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/tacotron-DCA\"\n",
-    "DATA_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/\"\n",
-    "DATASET = \"sam_accenture\"\n",
-    "METADATA_FILE = \"recording_script.xml\"\n",
-    "CONFIG_PATH = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/config.json\"\n",
-    "MODEL_FILE = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/best_model.pth.tar\"\n",
+    "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
+    "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
+    "DATASET = \"ljspeech\"\n",
+    "METADATA_FILE = \"metadata.csv\"\n",
+    "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
+    "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth.tar\"\n",
     "BATCH_SIZE = 32\n",
     "\n",
     "QUANTIZED_WAV = False\n",
@@ -78,56 +77,63 @@
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
     "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
+    "print(C['r'])\n",
     "# if the vocabulary was passed, replace the default\n",
-    "if 'characters' in C.keys():\n",
+    "if 'characters' in C and C['characters']:\n",
     "    symbols, phonemes = make_symbols(**C.characters)\n",
     "\n",
     "# load the model\n",
     "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
     "# TODO: multiple speaker\n",
-    "model = setup_model(num_chars, num_speakers=0, c=C)\n",
-    "checkpoint = torch.load(MODEL_FILE)\n",
-    "model.load_state_dict(checkpoint['model'])\n",
-    "print(checkpoint['step'])\n",
-    "model.eval()\n",
-    "model.decoder.set_r(checkpoint['r'])\n",
-    "if use_cuda:\n",
-    "    model = model.cuda()"
-   ]
+    "model = setup_model(C)\n",
+    "model.load_checkpoint(C, MODEL_FILE, eval=True)"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
-    "preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
+    "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
     "preprocessor = getattr(preprocessor, DATASET.lower())\n",
-    "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n",
-    "dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,characters=c.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
-    "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
+    "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
+    "dataset = TTSDataset(\n",
+    "    checkpoint[\"config\"][\"r\"],\n",
+    "    C.text_cleaner,\n",
+    "    False,\n",
+    "    ap,\n",
+    "    meta_data,\n",
+    "    characters=C.get('characters', None),\n",
+    "    use_phonemes=C.use_phonemes,\n",
+    "    phoneme_cache_path=C.phoneme_cache_path,\n",
+    "    enable_eos_bos=C.enable_eos_bos_chars,\n",
+    ")\n",
+    "loader = DataLoader(\n",
+    "    dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
+    ")\n"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "markdown",
-   "metadata": {},
    "source": [
     "### Generate model outputs "
-   ]
+   ],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "import pickle\n",
     "\n",
@@ -206,42 +212,42 @@
     "\n",
     "    print(np.mean(losses))\n",
     "    print(np.mean(postnet_losses))"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# for pwgan\n",
     "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
     "    for data in metadata:\n",
     "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "markdown",
-   "metadata": {},
    "source": [
     "### Sanity Check"
-   ]
+   ],
+   "metadata": {}
  },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "idx = 1\n",
     "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "import soundfile as sf\n",
     "wav, sr = sf.read(item_idx[idx])\n",
@@ -249,46 +255,46 @@
     "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
     "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
     "mel_truth = ap.melspectrogram(wav)\n",
     "print(mel_truth.shape)"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# plot posnet output\n",
     "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
     "plot_spectrogram(mel_postnet, ap)"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# plot decoder output\n",
     "print(mel_decoder.shape)\n",
     "plot_spectrogram(mel_decoder, ap)"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# plot GT specgrogram\n",
     "print(mel_truth.shape)\n",
     "plot_spectrogram(mel_truth.T, ap)"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# postnet, decoder diff\n",
     "from matplotlib import pylab as plt\n",
@@ -297,13 +303,13 @@
     "mel_diff = mel_decoder - mel_postnet\n",
     "plt.figure(figsize=(16, 10))\n",
     "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
     "from matplotlib import pylab as plt\n",
@@ -312,13 +318,13 @@
     "mel_diff2 = mel_truth.T - mel_decoder\n",
     "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
     "from matplotlib import pylab as plt\n",
@@ -328,21 +334,22 @@
     "mel = postnet_outputs[idx]\n",
     "mel_diff2 = mel_truth.T - mel\n",
     "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
-   ]
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "source": [],
    "outputs": [],
-   "source": []
+   "metadata": {}
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
+   "name": "python3",
+   "display_name": "Python 3.9.7 64-bit ('base': conda)"
   },
   "language_info": {
    "codemirror_mode": {
@@ -354,7 +361,10 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.9.7"
+  },
+  "interpreter": {
+   "hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
   }
  },
  "nbformat": 4,
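Editor's note (not part of the commit): the patched dataset cell now loads the model with `model.load_checkpoint(C, MODEL_FILE, eval=True)` and no longer defines `checkpoint`, yet the new `TTSDataset(...)` call still reads `checkpoint["config"]["r"]`, so the notebook as patched would raise a NameError at that cell. Below is a minimal sketch of that cell only, assuming the loaded config `C` exposes the reduction factor as `C.r` (the patch itself prints `C['r']` one cell earlier); `C`, `ap`, `meta_data`, and `BATCH_SIZE` are defined by the notebook's earlier cells.

    # Sketch only: mirrors the patched cell, but takes `r` from the config
    # instead of the removed `checkpoint` dict.
    from torch.utils.data import DataLoader
    from TTS.tts.datasets.TTSDataset import TTSDataset

    dataset = TTSDataset(
        C.r,  # assumption: reduction factor is available on the loaded config
        C.text_cleaner,
        False,
        ap,
        meta_data,
        characters=C.get("characters", None),
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        enable_eos_bos=C.enable_eos_bos_chars,
    )
    loader = DataLoader(
        dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False
    )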