From ddbaecdb5b8d9043075e4bd5ba73c587a5d89526 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 15 Nov 2023 15:52:57 +0100 Subject: [PATCH] fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code Fixes #2447, #2574 --- notebooks/ExtractTTSpectrogram.ipynb | 181 +++++++++++---------------- 1 file changed, 70 insertions(+), 111 deletions(-) diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 9acc9929..076fb50c 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -13,23 +13,27 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import sys\n", - "import torch\n", "import importlib\n", - "import numpy as np\n", - "from tqdm import tqdm\n", - "from torch.utils.data import DataLoader\n", - "import soundfile as sf\n", + "import os\n", "import pickle\n", + "\n", + "import numpy as np\n", + "import soundfile as sf\n", + "import torch\n", + "from matplotlib import pylab as plt\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "from TTS.config import load_config\n", + "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", "from TTS.tts.datasets.dataset import TTSDataset\n", "from TTS.tts.layers.losses import L1LossMasked\n", - "from TTS.utils.audio import AudioProcessor\n", - "from TTS.config import load_config\n", - "from TTS.tts.utils.visual import plot_spectrogram\n", - "from TTS.tts.utils.helpers import sequence_mask\n", "from TTS.tts.models import setup_model\n", - "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", + "from TTS.tts.utils.helpers import sequence_mask\n", + "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n", + "from TTS.tts.utils.visual import plot_spectrogram\n", + "from TTS.utils.audio import AudioProcessor\n", "\n", "%matplotlib inline\n", "\n", @@ -49,11 +53,9 @@ " file_name = wav_file.split('.')[0]\n", " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n", - " os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n", " wavq_path = os.path.join(out_path, \"quant\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n", - " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", - " return file_name, wavq_path, mel_path, wav_path" + " return file_name, wavq_path, mel_path" ] }, { @@ -65,14 +67,14 @@ "# Paths and configurations\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", + "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n", "DATASET = \"ljspeech\"\n", "METADATA_FILE = \"metadata.csv\"\n", "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n", "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n", "BATCH_SIZE = 32\n", "\n", - "QUANTIZED_WAV = False\n", - "QUANTIZE_BIT = None\n", + "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "\n", "# Check CUDA availability\n", @@ -80,10 +82,10 @@ "print(\" > CUDA enabled: \", use_cuda)\n", "\n", "# Load the configuration\n", + "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n", "C = load_config(CONFIG_PATH)\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", - "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n", - "print(C['r'])" + "ap = AudioProcessor(**C.audio)" ] }, { @@ -92,12 +94,10 @@ "metadata": {}, "outputs": [], "source": [ - "# If the vocabulary was passed, replace the default\n", - "if 'characters' in C and C['characters']:\n", - " symbols, phonemes = make_symbols(**C.characters)\n", + "# Initialize the tokenizer\n", + "tokenizer, C = TTSTokenizer.init_from_config(C)\n", "\n", "# Load the model\n", - "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speakers\n", "model = setup_model(C)\n", "model.load_checkpoint(C, MODEL_FILE, eval=True)" @@ -109,42 +109,21 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the preprocessor based on the dataset\n", - "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", - "preprocessor = getattr(preprocessor, DATASET.lower())\n", - "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", + "# Load data instances\n", + "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n", + "meta_data = meta_data_train + meta_data_eval\n", + "\n", "dataset = TTSDataset(\n", - " C,\n", - " C.text_cleaner,\n", - " False,\n", - " ap,\n", - " meta_data,\n", - " characters=C.get('characters', None),\n", - " use_phonemes=C.use_phonemes,\n", - " phoneme_cache_path=C.phoneme_cache_path,\n", - " enable_eos_bos=C.enable_eos_bos_chars,\n", + " outputs_per_step=C[\"r\"],\n", + " compute_linear_spec=False,\n", + " ap=ap,\n", + " samples=meta_data,\n", + " tokenizer=tokenizer,\n", + " phoneme_cache_path=PHONEME_CACHE_PATH,\n", ")\n", "loader = DataLoader(\n", " dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize lists for storing results\n", - "file_idxs = []\n", - "metadata = []\n", - "losses = []\n", - "postnet_losses = []\n", - "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", - "\n", - "# Create log file\n", - "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", - "log_file = open(log_file_path, \"w\")" + ")" ] }, { @@ -160,26 +139,33 @@ "metadata": {}, "outputs": [], "source": [ + "# Initialize lists for storing results\n", + "file_idxs = []\n", + "metadata = []\n", + "losses = []\n", + "postnet_losses = []\n", + "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", + "\n", "# Start processing with a progress bar\n", - "with torch.no_grad():\n", + "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", + "with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n", " for data in tqdm(loader, desc=\"Processing\"):\n", " try:\n", - " # setup input data\n", - " text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n", - "\n", " # dispatch data to GPU\n", " if use_cuda:\n", - " text_input = text_input.cuda()\n", - " text_lengths = text_lengths.cuda()\n", - " mel_input = mel_input.cuda()\n", - " mel_lengths = mel_lengths.cuda()\n", + " data[\"token_id\"] = data[\"token_id\"].cuda()\n", + " data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n", + " data[\"mel\"] = data[\"mel\"].cuda()\n", + " data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n", "\n", - " mask = sequence_mask(text_lengths)\n", - " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", + " mask = sequence_mask(data[\"token_id_lengths\"])\n", + " outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n", + " mel_outputs = outputs[\"decoder_outputs\"]\n", + " postnet_outputs = outputs[\"model_outputs\"]\n", "\n", " # compute loss\n", - " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", - " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", + " loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", + " loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", " losses.append(loss.item())\n", " postnet_losses.append(loss_postnet.item())\n", "\n", @@ -193,28 +179,27 @@ " postnet_outputs = torch.stack(mel_specs)\n", " elif C.model == \"Tacotron2\":\n", " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", - " alignments = alignments.detach().cpu().numpy()\n", + " alignments = outputs[\"alignments\"].detach().cpu().numpy()\n", "\n", " if not DRY_RUN:\n", - " for idx in range(text_input.shape[0]):\n", - " wav_file_path = item_idx[idx]\n", + " for idx in range(data[\"token_id\"].shape[0]):\n", + " wav_file_path = data[\"item_idxs\"][idx]\n", " wav = ap.load_wav(wav_file_path)\n", - " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", + " file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n", " file_idxs.append(file_name)\n", "\n", " # quantize and save wav\n", - " if QUANTIZED_WAV:\n", - " wavq = ap.quantize(wav)\n", + " if QUANTIZE_BITS > 0:\n", + " wavq = ap.quantize(wav, QUANTIZE_BITS)\n", " np.save(wavq_path, wavq)\n", "\n", " # save TTS mel\n", " mel = postnet_outputs[idx]\n", - " mel_length = mel_lengths[idx]\n", + " mel_length = data[\"mel_lengths\"][idx]\n", " mel = mel[:mel_length, :].T\n", " np.save(mel_path, mel)\n", "\n", " metadata.append([wav_file_path, mel_path])\n", - "\n", " except Exception as e:\n", " log_file.write(f\"Error processing data: {str(e)}\\n\")\n", "\n", @@ -224,35 +209,20 @@ " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n", " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n", "\n", - "# Close the log file\n", - "log_file.close()\n", - "\n", "# For wavernn\n", "if not DRY_RUN:\n", " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n", "\n", "# For pwgan\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", + " for wav_file_path, mel_path in metadata:\n", + " f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n", "\n", "# Print mean losses\n", "print(f\"Mean Loss: {mean_loss}\")\n", "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for pwgan\n", - "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -267,7 +237,7 @@ "outputs": [], "source": [ "idx = 1\n", - "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" + "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape" ] }, { @@ -276,10 +246,9 @@ "metadata": {}, "outputs": [], "source": [ - "import soundfile as sf\n", - "wav, sr = sf.read(item_idx[idx])\n", - "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n", - "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", + "wav, sr = sf.read(data[\"item_idxs\"][idx])\n", + "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n", + "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n", "mel_truth = ap.melspectrogram(wav)\n", "print(mel_truth.shape)" ] @@ -291,7 +260,7 @@ "outputs": [], "source": [ "# plot posnet output\n", - "print(mel_postnet[:mel_lengths[idx], :].shape)\n", + "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n", "plot_spectrogram(mel_postnet, ap)" ] }, @@ -324,10 +293,9 @@ "outputs": [], "source": [ "# postnet, decoder diff\n", - "from matplotlib import pylab as plt\n", "mel_diff = mel_decoder - mel_postnet\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -339,10 +307,9 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel_diff2 = mel_truth.T - mel_decoder\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -354,21 +321,13 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel = postnet_outputs[idx]\n", "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {