Update notebook compat

This commit is contained in:
Eren Gölge 2021-09-09 10:59:23 +00:00
parent bfc6ceac29
commit 1de010acd4
1 changed files with 90 additions and 80 deletions

View File

@ -2,16 +2,14 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training." "This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
] ],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"%load_ext autoreload\n", "%load_ext autoreload\n",
"%autoreload 2\n", "%autoreload 2\n",
@ -25,22 +23,23 @@
"from TTS.tts.datasets.TTSDataset import TTSDataset\n", "from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n", "from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n", "from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n", "from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.generic_utils import setup_model, sequence_mask\n", "from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
"\n", "\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"\n", "\n",
"import os\n", "import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='0'" "os.environ['CUDA_VISIBLE_DEVICES']='2'"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"def set_filename(wav_path, out_path):\n", "def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n", " wav_file = os.path.basename(wav_path)\n",
@ -52,20 +51,20 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path" " return file_name, wavq_path, mel_path, wav_path"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"OUT_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/tacotron-DCA\"\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"DATASET = \"sam_accenture\"\n", "DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"recording_script.xml\"\n", "METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/config.json\"\n", "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/best_model.pth.tar\"\n", "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth.tar\"\n",
"BATCH_SIZE = 32\n", "BATCH_SIZE = 32\n",
"\n", "\n",
"QUANTIZED_WAV = False\n", "QUANTIZED_WAV = False\n",
@ -78,56 +77,63 @@
"C = load_config(CONFIG_PATH)\n", "C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)" "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n", "# if the vocabulary was passed, replace the default\n",
"if 'characters' in C.keys():\n", "if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n", " symbols, phonemes = make_symbols(**C.characters)\n",
"\n", "\n",
"# load the model\n", "# load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speaker\n", "# TODO: multiple speaker\n",
"model = setup_model(num_chars, num_speakers=0, c=C)\n", "model = setup_model(C)\n",
"checkpoint = torch.load(MODEL_FILE)\n", "model.load_checkpoint(C, MODEL_FILE, eval=True)"
"model.load_state_dict(checkpoint['model'])\n", ],
"print(checkpoint['step'])\n", "outputs": [],
"model.eval()\n", "metadata": {}
"model.decoder.set_r(checkpoint['r'])\n",
"if use_cuda:\n",
" model = model.cuda()"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n", "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,characters=c.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "dataset = TTSDataset(\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" " checkpoint[\"config\"][\"r\"],\n",
] " C.text_cleaner,\n",
" False,\n",
" ap,\n",
" meta_data,\n",
" characters=C.get('characters', None),\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
")\n",
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### Generate model outputs " "### Generate model outputs "
] ],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import pickle\n", "import pickle\n",
"\n", "\n",
@ -206,42 +212,42 @@
"\n", "\n",
" print(np.mean(losses))\n", " print(np.mean(losses))\n",
" print(np.mean(postnet_losses))" " print(np.mean(postnet_losses))"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# for pwgan\n", "# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n", " for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {},
"source": [ "source": [
"### Sanity Check" "### Sanity Check"
] ],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"idx = 1\n", "idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"import soundfile as sf\n", "import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n", "wav, sr = sf.read(item_idx[idx])\n",
@ -249,46 +255,46 @@
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n", "mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)" "print(mel_truth.shape)"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot posnet output\n", "# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n", "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)" "plot_spectrogram(mel_postnet, ap)"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot decoder output\n", "# plot decoder output\n",
"print(mel_decoder.shape)\n", "print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)" "plot_spectrogram(mel_decoder, ap)"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# plot GT specgrogram\n", "# plot GT specgrogram\n",
"print(mel_truth.shape)\n", "print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)" "plot_spectrogram(mel_truth.T, ap)"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# postnet, decoder diff\n", "# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -297,13 +303,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -312,13 +318,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# PLOT GT SPECTROGRAM diff\n", "# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n", "from matplotlib import pylab as plt\n",
@ -328,21 +334,22 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n", "plt.colorbar()\n",
"plt.tight_layout()" "plt.tight_layout()"
] ],
"outputs": [],
"metadata": {}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "source": [],
"outputs": [], "outputs": [],
"source": [] "metadata": {}
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "name": "python3",
"language": "python", "display_name": "Python 3.9.7 64-bit ('base': conda)"
"name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -354,7 +361,10 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
} }
}, },
"nbformat": 4, "nbformat": 4,