fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code

Fixes #2447, #2574
Enno Hermann 2023-11-15 15:52:57 +01:00
parent aa0fbdf27e
commit ddbaecdb5b
1 changed file with 70 additions and 111 deletions
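The notebook predated the current Coqui TTS API; the diff below ports it to the formatter/tokenizer-based pipeline. As orientation, this is a minimal sketch of the loading pattern the notebook moves to, drawn from the diff itself; the `config.json` and `data/` paths here are placeholder assumptions, everything else follows the calls introduced by this commit:

```python
# Sketch of the data pipeline this commit migrates to (not part of the diff itself).
from TTS.config import load_config
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

C = load_config("config.json")                   # model config shipped with the checkpoint
ap = AudioProcessor(**C.audio)                   # audio front-end built from the config
tokenizer, C = TTSTokenizer.init_from_config(C)  # replaces make_symbols/symbols/phonemes

# Formatters are now looked up by name instead of via importlib tricks.
dataset_config = BaseDatasetConfig(formatter="ljspeech", meta_file_train="metadata.csv", path="data/")
train_samples, eval_samples = load_tts_samples(dataset_config)

dataset = TTSDataset(
    outputs_per_step=C["r"],     # decoder frames per step
    compute_linear_spec=False,   # mel spectrograms only
    ap=ap,
    samples=train_samples + eval_samples,
    tokenizer=tokenizer,
)
```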

notebooks/ExtractTTSpectrogram.ipynb

@@ -13,23 +13,27 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os\n",
-    "import sys\n",
-    "import torch\n",
     "import importlib\n",
-    "import numpy as np\n",
-    "from tqdm import tqdm\n",
-    "from torch.utils.data import DataLoader\n",
-    "import soundfile as sf\n",
+    "import os\n",
     "import pickle\n",
+    "\n",
+    "import numpy as np\n",
+    "import soundfile as sf\n",
+    "import torch\n",
+    "from matplotlib import pylab as plt\n",
+    "from torch.utils.data import DataLoader\n",
+    "from tqdm import tqdm\n",
+    "\n",
+    "from TTS.config import load_config\n",
+    "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
+    "from TTS.tts.datasets import load_tts_samples\n",
     "from TTS.tts.datasets.dataset import TTSDataset\n",
     "from TTS.tts.layers.losses import L1LossMasked\n",
-    "from TTS.utils.audio import AudioProcessor\n",
-    "from TTS.config import load_config\n",
-    "from TTS.tts.utils.visual import plot_spectrogram\n",
-    "from TTS.tts.utils.helpers import sequence_mask\n",
     "from TTS.tts.models import setup_model\n",
-    "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
+    "from TTS.tts.utils.helpers import sequence_mask\n",
+    "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
+    "from TTS.tts.utils.visual import plot_spectrogram\n",
+    "from TTS.utils.audio import AudioProcessor\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
@@ -49,11 +53,9 @@
     "    file_name = wav_file.split('.')[0]\n",
     "    os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
     "    os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
-    "    os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
     "    wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
     "    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
-    "    wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
-    "    return file_name, wavq_path, mel_path, wav_path"
+    "    return file_name, wavq_path, mel_path"
    ]
   },
   {
@@ -65,14 +67,14 @@
     "# Paths and configurations\n",
     "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
     "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
+    "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
     "DATASET = \"ljspeech\"\n",
     "METADATA_FILE = \"metadata.csv\"\n",
     "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
     "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
     "BATCH_SIZE = 32\n",
     "\n",
-    "QUANTIZED_WAV = False\n",
-    "QUANTIZE_BIT = None\n",
+    "QUANTIZE_BITS = 0  # if non-zero, quantize wav files with the given number of bits\n",
     "DRY_RUN = False  # if True, do not generate output files, only compute loss and visuals\n",
     "\n",
     "# Check CUDA availability\n",
@@ -80,10 +82,10 @@
     "print(\" > CUDA enabled: \", use_cuda)\n",
     "\n",
     "# Load the configuration\n",
+    "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
     "C = load_config(CONFIG_PATH)\n",
     "C.audio['do_trim_silence'] = False  # IMPORTANT! disable to align mel specs with the wav files\n",
-    "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
-    "print(C['r'])"
+    "ap = AudioProcessor(**C.audio)"
    ]
   },
   {
@@ -92,12 +94,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# If the vocabulary was passed, replace the default\n",
-    "if 'characters' in C and C['characters']:\n",
-    "    symbols, phonemes = make_symbols(**C.characters)\n",
+    "# Initialize the tokenizer\n",
+    "tokenizer, C = TTSTokenizer.init_from_config(C)\n",
     "\n",
     "# Load the model\n",
-    "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
     "# TODO: multiple speakers\n",
     "model = setup_model(C)\n",
     "model.load_checkpoint(C, MODEL_FILE, eval=True)"
@@ -109,42 +109,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Load the preprocessor based on the dataset\n",
-    "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
-    "preprocessor = getattr(preprocessor, DATASET.lower())\n",
-    "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
+    "# Load data instances\n",
+    "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
+    "meta_data = meta_data_train + meta_data_eval\n",
+    "\n",
     "dataset = TTSDataset(\n",
-    "    C,\n",
-    "    C.text_cleaner,\n",
-    "    False,\n",
-    "    ap,\n",
-    "    meta_data,\n",
-    "    characters=C.get('characters', None),\n",
-    "    use_phonemes=C.use_phonemes,\n",
-    "    phoneme_cache_path=C.phoneme_cache_path,\n",
-    "    enable_eos_bos=C.enable_eos_bos_chars,\n",
+    "    outputs_per_step=C[\"r\"],\n",
+    "    compute_linear_spec=False,\n",
+    "    ap=ap,\n",
+    "    samples=meta_data,\n",
+    "    tokenizer=tokenizer,\n",
+    "    phoneme_cache_path=PHONEME_CACHE_PATH,\n",
     ")\n",
     "loader = DataLoader(\n",
     "    dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
-    ")\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Initialize lists for storing results\n",
-    "file_idxs = []\n",
-    "metadata = []\n",
-    "losses = []\n",
-    "postnet_losses = []\n",
-    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
-    "\n",
-    "# Create log file\n",
-    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
-    "log_file = open(log_file_path, \"w\")"
+    ")"
    ]
   },
   {
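The rewritten `TTSDataset.collate_fn` yields dict batches rather than tuples; the key names below are the ones used in the following hunks, while the shape comments are assumptions for illustration. A minimal sketch of inspecting one batch:

```python
# Sketch: iterating the new dict-style batches (key names as in the diff below).
for data in loader:
    tokens = data["token_id"]              # assumed LongTensor [batch, max_text_len]
    token_lens = data["token_id_lengths"]  # assumed LongTensor [batch]
    mels = data["mel"]                     # assumed FloatTensor [batch, max_mel_len, n_mels]
    mel_lens = data["mel_lengths"]         # assumed LongTensor [batch]
    wav_paths = data["item_idxs"]          # list of source wav file paths
    break  # just inspect the first batch
```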
@@ -160,26 +139,33 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Initialize lists for storing results\n",
+    "file_idxs = []\n",
+    "metadata = []\n",
+    "losses = []\n",
+    "postnet_losses = []\n",
+    "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
+    "\n",
     "# Start processing with a progress bar\n",
-    "with torch.no_grad():\n",
+    "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
+    "with torch.no_grad(), open(log_file_path, \"w\") as log_file:\n",
     "    for data in tqdm(loader, desc=\"Processing\"):\n",
     "        try:\n",
-    "            # setup input data\n",
-    "            text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
-    "\n",
     "            # dispatch data to GPU\n",
     "            if use_cuda:\n",
-    "                text_input = text_input.cuda()\n",
-    "                text_lengths = text_lengths.cuda()\n",
-    "                mel_input = mel_input.cuda()\n",
-    "                mel_lengths = mel_lengths.cuda()\n",
+    "                data[\"token_id\"] = data[\"token_id\"].cuda()\n",
+    "                data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
+    "                data[\"mel\"] = data[\"mel\"].cuda()\n",
+    "                data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
     "\n",
-    "            mask = sequence_mask(text_lengths)\n",
-    "            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
+    "            mask = sequence_mask(data[\"token_id_lengths\"])\n",
+    "            outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
+    "            mel_outputs = outputs[\"decoder_outputs\"]\n",
+    "            postnet_outputs = outputs[\"model_outputs\"]\n",
     "\n",
     "            # compute loss\n",
-    "            loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
-    "            loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
+    "            loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
+    "            loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
     "            losses.append(loss.item())\n",
     "            postnet_losses.append(loss_postnet.item())\n",
     "\n",
@@ -193,28 +179,27 @@
     "                postnet_outputs = torch.stack(mel_specs)\n",
     "            elif C.model == \"Tacotron2\":\n",
     "                postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
-    "            alignments = alignments.detach().cpu().numpy()\n",
+    "            alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
     "\n",
     "            if not DRY_RUN:\n",
-    "                for idx in range(text_input.shape[0]):\n",
-    "                    wav_file_path = item_idx[idx]\n",
+    "                for idx in range(data[\"token_id\"].shape[0]):\n",
+    "                    wav_file_path = data[\"item_idxs\"][idx]\n",
     "                    wav = ap.load_wav(wav_file_path)\n",
-    "                    file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
+    "                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
     "                    file_idxs.append(file_name)\n",
     "\n",
     "                    # quantize and save wav\n",
-    "                    if QUANTIZED_WAV:\n",
-    "                        wavq = ap.quantize(wav)\n",
+    "                    if QUANTIZE_BITS > 0:\n",
+    "                        wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
     "                        np.save(wavq_path, wavq)\n",
     "\n",
     "                    # save TTS mel\n",
     "                    mel = postnet_outputs[idx]\n",
-    "                    mel_length = mel_lengths[idx]\n",
+    "                    mel_length = data[\"mel_lengths\"][idx]\n",
     "                    mel = mel[:mel_length, :].T\n",
     "                    np.save(mel_path, mel)\n",
     "\n",
     "                    metadata.append([wav_file_path, mel_path])\n",
-    "\n",
     "        except Exception as e:\n",
     "            log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
     "\n",
@@ -224,35 +209,20 @@
     "    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
     "    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
     "\n",
-    "# Close the log file\n",
-    "log_file.close()\n",
-    "\n",
     "# For wavernn\n",
     "if not DRY_RUN:\n",
     "    pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
     "\n",
     "# For pwgan\n",
     "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
+    "    for wav_file_path, mel_path in metadata:\n",
+    "        f.write(f\"{wav_file_path}|{mel_path}.npy\\n\")\n",
     "\n",
     "# Print mean losses\n",
     "print(f\"Mean Loss: {mean_loss}\")\n",
     "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# for pwgan\n",
-    "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
-    "    for data in metadata:\n",
-    "        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -267,7 +237,7 @@
    "outputs": [],
    "source": [
     "idx = 1\n",
-    "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
+    "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
    ]
   },
   {
@@ -276,10 +246,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import soundfile as sf\n",
-    "wav, sr = sf.read(item_idx[idx])\n",
-    "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
-    "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
+    "wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
+    "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
+    "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
     "mel_truth = ap.melspectrogram(wav)\n",
     "print(mel_truth.shape)"
    ]
@@ -291,7 +260,7 @@
    "outputs": [],
    "source": [
     "# plot postnet output\n",
-    "print(mel_postnet[:mel_lengths[idx], :].shape)\n",
+    "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
     "plot_spectrogram(mel_postnet, ap)"
    ]
   },
@@ -324,10 +293,9 @@
    "outputs": [],
    "source": [
     "# postnet, decoder diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff = mel_decoder - mel_postnet\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T, aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -339,10 +307,9 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel_diff2 = mel_truth.T - mel_decoder\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
@@ -354,21 +321,13 @@
    "outputs": [],
    "source": [
     "# PLOT GT SPECTROGRAM diff\n",
-    "from matplotlib import pylab as plt\n",
     "mel = postnet_outputs[idx]\n",
     "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
     "plt.figure(figsize=(16, 10))\n",
-    "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
+    "plt.imshow(abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
     "plt.colorbar()\n",
     "plt.tight_layout()"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
 ],
 "metadata": {