fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code

Fixes #2447, #2574
This commit is contained in:
Enno Hermann 2023-11-15 15:52:57 +01:00
parent aa0fbdf27e
commit ddbaecdb5b
1 changed files with 70 additions and 111 deletions

View File

@ -13,23 +13,27 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"import importlib\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"import soundfile as sf\n",
"import os\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"import soundfile as sf\n",
"import torch\n",
"from matplotlib import pylab as plt\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"from TTS.config import load_config\n",
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
"from TTS.tts.datasets import load_tts_samples\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"\n",
"%matplotlib inline\n",
"\n",
@ -49,11 +53,9 @@
" file_name = wav_file.split('.')[0]\n",
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
" return file_name, wavq_path, mel_path"
]
},
{
@ -65,14 +67,14 @@
"# Paths and configurations\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
"DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
"BATCH_SIZE = 32\n",
"\n",
"QUANTIZED_WAV = False\n",
"QUANTIZE_BIT = None\n",
"QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
"\n",
"# Check CUDA availability\n",
@ -80,10 +82,10 @@
"print(\" > CUDA enabled: \", use_cuda)\n",
"\n",
"# Load the configuration\n",
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
"print(C['r'])"
"ap = AudioProcessor(**C.audio)"
]
},
{
@ -92,12 +94,10 @@
"metadata": {},
"outputs": [],
"source": [
"# If the vocabulary was passed, replace the default\n",
"if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"# Initialize the tokenizer\n",
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
"\n",
"# Load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speakers\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
@ -109,42 +109,21 @@
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessor based on the dataset\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"# Load data instances\n",
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
"meta_data = meta_data_train + meta_data_eval\n",
"\n",
"dataset = TTSDataset(\n",
" C,\n",
" C.text_cleaner,\n",
" False,\n",
" ap,\n",
" meta_data,\n",
" characters=C.get('characters', None),\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
" outputs_per_step=C[\"r\"],\n",
" compute_linear_spec=False,\n",
" ap=ap,\n",
" samples=meta_data,\n",
" tokenizer=tokenizer,\n",
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
")\n",
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Create log file\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"log_file = open(log_file_path, \"w\")"
")"
]
},
{
@ -160,26 +139,33 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Start processing with a progress bar\n",
"with torch.no_grad():\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
" for data in tqdm(loader, desc=\"Processing\"):\n",
" try:\n",
" # setup input data\n",
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
"\n",
" # dispatch data to GPU\n",
" if use_cuda:\n",
" text_input = text_input.cuda()\n",
" text_lengths = text_lengths.cuda()\n",
" mel_input = mel_input.cuda()\n",
" mel_lengths = mel_lengths.cuda()\n",
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
" data[\"mel\"] = data[\"mel\"].cuda()\n",
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
"\n",
" mask = sequence_mask(text_lengths)\n",
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
" mel_outputs = outputs[\"decoder_outputs\"]\n",
" postnet_outputs = outputs[\"model_outputs\"]\n",
"\n",
" # compute loss\n",
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" losses.append(loss.item())\n",
" postnet_losses.append(loss_postnet.item())\n",
"\n",
@ -193,28 +179,27 @@
" postnet_outputs = torch.stack(mel_specs)\n",
" elif C.model == \"Tacotron2\":\n",
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
" alignments = alignments.detach().cpu().numpy()\n",
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
"\n",
" if not DRY_RUN:\n",
" for idx in range(text_input.shape[0]):\n",
" wav_file_path = item_idx[idx]\n",
" for idx in range(data[\"token_id\"].shape[0]):\n",
" wav_file_path = data[\"item_idxs\"][idx]\n",
" wav = ap.load_wav(wav_file_path)\n",
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_idxs.append(file_name)\n",
"\n",
" # quantize and save wav\n",
" if QUANTIZED_WAV:\n",
" wavq = ap.quantize(wav)\n",
" if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" # save TTS mel\n",
" mel = postnet_outputs[idx]\n",
" mel_length = mel_lengths[idx]\n",
" mel_length = data[\"mel_lengths\"][idx]\n",
" mel = mel[:mel_length, :].T\n",
" np.save(mel_path, mel)\n",
"\n",
" metadata.append([wav_file_path, mel_path])\n",
"\n",
" except Exception as e:\n",
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
"\n",
@ -224,35 +209,20 @@
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n",
"# Close the log file\n",
"log_file.close()\n",
"\n",
"# For wavernn\n",
"if not DRY_RUN:\n",
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
"\n",
"# For pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
" for wav_file_path, mel_path in metadata:\n",
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
"\n",
"# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -267,7 +237,7 @@
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
]
},
{
@ -276,10 +246,9 @@
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
]
@ -291,7 +260,7 @@
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
]
},
@ -324,10 +293,9 @@
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff = mel_decoder - mel_postnet\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
@ -339,10 +307,9 @@
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff2 = mel_truth.T - mel_decoder\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
@ -354,21 +321,13 @@
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel = postnet_outputs[idx]\n",
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {