Update notebook compat

This commit is contained in:
Eren Gölge 2021-09-09 10:59:23 +00:00
parent bfc6ceac29
commit 1de010acd4
1 changed files with 90 additions and 80 deletions

View File

@ -2,16 +2,14 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training."
]
"This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@ -25,22 +23,23 @@
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.io import load_config\n",
"from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.generic_utils import setup_model, sequence_mask\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
"\n",
"%matplotlib inline\n",
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='0'"
]
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n",
@ -52,20 +51,20 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"OUT_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/tacotron-DCA\"\n",
"DATA_PATH = \"/home/erogol/gdrive/Datasets/non-binary-voice-files/\"\n",
"DATASET = \"sam_accenture\"\n",
"METADATA_FILE = \"recording_script.xml\"\n",
"CONFIG_PATH = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/config.json\"\n",
"MODEL_FILE = \"/home/erogol/gdrive/Trainings/sam/ljspeech-dcattn-April-03-2021_05+02-2344379/best_model.pth.tar\"\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth.tar\"\n",
"BATCH_SIZE = 32\n",
"\n",
"QUANTIZED_WAV = False\n",
@ -78,56 +77,63 @@
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
"if 'characters' in C.keys():\n",
"if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"\n",
"# load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speaker\n",
"model = setup_model(num_chars, num_speakers=0, c=C)\n",
"checkpoint = torch.load(MODEL_FILE)\n",
"model.load_state_dict(checkpoint['model'])\n",
"print(checkpoint['step'])\n",
"model.eval()\n",
"model.decoder.set_r(checkpoint['r'])\n",
"if use_cuda:\n",
" model = model.cuda()"
]
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"dataset = TTSDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,characters=c.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
"loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
]
"dataset = TTSDataset(\n",
" checkpoint[\"config\"][\"r\"],\n",
" C.text_cleaner,\n",
" False,\n",
" ap,\n",
" meta_data,\n",
" characters=C.get('characters', None),\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
")\n",
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate model outputs "
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
@ -206,42 +212,42 @@
"\n",
" print(np.mean(losses))\n",
" print(np.mean(postnet_losses))"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Check"
]
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
@ -249,46 +255,46 @@
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot decoder output\n",
"print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot GT specgrogram\n",
"print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
@ -297,13 +303,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@ -312,13 +318,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@ -328,21 +334,22 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"source": [],
"outputs": [],
"source": []
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
"name": "python3",
"display_name": "Python 3.9.7 64-bit ('base': conda)"
},
"language_info": {
"codemirror_mode": {
@ -354,7 +361,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
}
},
"nbformat": 4,