mirror of https://github.com/coqui-ai/TTS.git
fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code
Fixes #2447, #2574
This commit is contained in:
parent
aa0fbdf27e
commit
ddbaecdb5b
|
@ -13,23 +13,27 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
|
||||||
"import sys\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import importlib\n",
|
"import importlib\n",
|
||||||
"import numpy as np\n",
|
"import os\n",
|
||||||
"from tqdm import tqdm\n",
|
|
||||||
"from torch.utils.data import DataLoader\n",
|
|
||||||
"import soundfile as sf\n",
|
|
||||||
"import pickle\n",
|
"import pickle\n",
|
||||||
|
"\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import soundfile as sf\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from matplotlib import pylab as plt\n",
|
||||||
|
"from torch.utils.data import DataLoader\n",
|
||||||
|
"from tqdm import tqdm\n",
|
||||||
|
"\n",
|
||||||
|
"from TTS.config import load_config\n",
|
||||||
|
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
|
||||||
|
"from TTS.tts.datasets import load_tts_samples\n",
|
||||||
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
"from TTS.tts.datasets.dataset import TTSDataset\n",
|
||||||
"from TTS.tts.layers.losses import L1LossMasked\n",
|
"from TTS.tts.layers.losses import L1LossMasked\n",
|
||||||
"from TTS.utils.audio import AudioProcessor\n",
|
|
||||||
"from TTS.config import load_config\n",
|
|
||||||
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
|
||||||
"from TTS.tts.utils.helpers import sequence_mask\n",
|
|
||||||
"from TTS.tts.models import setup_model\n",
|
"from TTS.tts.models import setup_model\n",
|
||||||
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
|
"from TTS.tts.utils.helpers import sequence_mask\n",
|
||||||
|
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
|
||||||
|
"from TTS.tts.utils.visual import plot_spectrogram\n",
|
||||||
|
"from TTS.utils.audio import AudioProcessor\n",
|
||||||
"\n",
|
"\n",
|
||||||
"%matplotlib inline\n",
|
"%matplotlib inline\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -49,11 +53,9 @@
|
||||||
" file_name = wav_file.split('.')[0]\n",
|
" file_name = wav_file.split('.')[0]\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
|
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
|
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
|
||||||
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
|
|
||||||
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
|
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
|
||||||
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
|
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
|
||||||
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
|
" return file_name, wavq_path, mel_path"
|
||||||
" return file_name, wavq_path, mel_path, wav_path"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -65,14 +67,14 @@
|
||||||
"# Paths and configurations\n",
|
"# Paths and configurations\n",
|
||||||
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
|
||||||
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
|
||||||
|
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
|
||||||
"DATASET = \"ljspeech\"\n",
|
"DATASET = \"ljspeech\"\n",
|
||||||
"METADATA_FILE = \"metadata.csv\"\n",
|
"METADATA_FILE = \"metadata.csv\"\n",
|
||||||
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
|
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
|
||||||
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
|
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
|
||||||
"BATCH_SIZE = 32\n",
|
"BATCH_SIZE = 32\n",
|
||||||
"\n",
|
"\n",
|
||||||
"QUANTIZED_WAV = False\n",
|
"QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
|
||||||
"QUANTIZE_BIT = None\n",
|
|
||||||
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Check CUDA availability\n",
|
"# Check CUDA availability\n",
|
||||||
|
@ -80,10 +82,10 @@
|
||||||
"print(\" > CUDA enabled: \", use_cuda)\n",
|
"print(\" > CUDA enabled: \", use_cuda)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the configuration\n",
|
"# Load the configuration\n",
|
||||||
|
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
|
||||||
"C = load_config(CONFIG_PATH)\n",
|
"C = load_config(CONFIG_PATH)\n",
|
||||||
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
|
||||||
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
|
"ap = AudioProcessor(**C.audio)"
|
||||||
"print(C['r'])"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -92,12 +94,10 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# If the vocabulary was passed, replace the default\n",
|
"# Initialize the tokenizer\n",
|
||||||
"if 'characters' in C and C['characters']:\n",
|
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
|
||||||
" symbols, phonemes = make_symbols(**C.characters)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"# Load the model\n",
|
"# Load the model\n",
|
||||||
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
|
|
||||||
"# TODO: multiple speakers\n",
|
"# TODO: multiple speakers\n",
|
||||||
"model = setup_model(C)\n",
|
"model = setup_model(C)\n",
|
||||||
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
|
||||||
|
@ -109,42 +109,21 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Load the preprocessor based on the dataset\n",
|
"# Load data instances\n",
|
||||||
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
|
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
|
||||||
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
|
"meta_data = meta_data_train + meta_data_eval\n",
|
||||||
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
|
"\n",
|
||||||
"dataset = TTSDataset(\n",
|
"dataset = TTSDataset(\n",
|
||||||
" C,\n",
|
" outputs_per_step=C[\"r\"],\n",
|
||||||
" C.text_cleaner,\n",
|
" compute_linear_spec=False,\n",
|
||||||
" False,\n",
|
" ap=ap,\n",
|
||||||
" ap,\n",
|
" samples=meta_data,\n",
|
||||||
" meta_data,\n",
|
" tokenizer=tokenizer,\n",
|
||||||
" characters=C.get('characters', None),\n",
|
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
|
||||||
" use_phonemes=C.use_phonemes,\n",
|
|
||||||
" phoneme_cache_path=C.phoneme_cache_path,\n",
|
|
||||||
" enable_eos_bos=C.enable_eos_bos_chars,\n",
|
|
||||||
")\n",
|
")\n",
|
||||||
"loader = DataLoader(\n",
|
"loader = DataLoader(\n",
|
||||||
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
|
||||||
")\n"
|
")"
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Initialize lists for storing results\n",
|
|
||||||
"file_idxs = []\n",
|
|
||||||
"metadata = []\n",
|
|
||||||
"losses = []\n",
|
|
||||||
"postnet_losses = []\n",
|
|
||||||
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
|
||||||
"\n",
|
|
||||||
"# Create log file\n",
|
|
||||||
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
|
||||||
"log_file = open(log_file_path, \"w\")"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -160,26 +139,33 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"# Initialize lists for storing results\n",
|
||||||
|
"file_idxs = []\n",
|
||||||
|
"metadata = []\n",
|
||||||
|
"losses = []\n",
|
||||||
|
"postnet_losses = []\n",
|
||||||
|
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
|
||||||
|
"\n",
|
||||||
"# Start processing with a progress bar\n",
|
"# Start processing with a progress bar\n",
|
||||||
"with torch.no_grad():\n",
|
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
|
||||||
|
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
|
||||||
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
" for data in tqdm(loader, desc=\"Processing\"):\n",
|
||||||
" try:\n",
|
" try:\n",
|
||||||
" # setup input data\n",
|
|
||||||
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
|
|
||||||
"\n",
|
|
||||||
" # dispatch data to GPU\n",
|
" # dispatch data to GPU\n",
|
||||||
" if use_cuda:\n",
|
" if use_cuda:\n",
|
||||||
" text_input = text_input.cuda()\n",
|
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
|
||||||
" text_lengths = text_lengths.cuda()\n",
|
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
|
||||||
" mel_input = mel_input.cuda()\n",
|
" data[\"mel\"] = data[\"mel\"].cuda()\n",
|
||||||
" mel_lengths = mel_lengths.cuda()\n",
|
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" mask = sequence_mask(text_lengths)\n",
|
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
|
||||||
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
|
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
|
||||||
|
" mel_outputs = outputs[\"decoder_outputs\"]\n",
|
||||||
|
" postnet_outputs = outputs[\"model_outputs\"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # compute loss\n",
|
" # compute loss\n",
|
||||||
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
|
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
|
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
|
||||||
" losses.append(loss.item())\n",
|
" losses.append(loss.item())\n",
|
||||||
" postnet_losses.append(loss_postnet.item())\n",
|
" postnet_losses.append(loss_postnet.item())\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -193,28 +179,27 @@
|
||||||
" postnet_outputs = torch.stack(mel_specs)\n",
|
" postnet_outputs = torch.stack(mel_specs)\n",
|
||||||
" elif C.model == \"Tacotron2\":\n",
|
" elif C.model == \"Tacotron2\":\n",
|
||||||
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
|
||||||
" alignments = alignments.detach().cpu().numpy()\n",
|
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if not DRY_RUN:\n",
|
" if not DRY_RUN:\n",
|
||||||
" for idx in range(text_input.shape[0]):\n",
|
" for idx in range(data[\"token_id\"].shape[0]):\n",
|
||||||
" wav_file_path = item_idx[idx]\n",
|
" wav_file_path = data[\"item_idxs\"][idx]\n",
|
||||||
" wav = ap.load_wav(wav_file_path)\n",
|
" wav = ap.load_wav(wav_file_path)\n",
|
||||||
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
|
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
|
||||||
" file_idxs.append(file_name)\n",
|
" file_idxs.append(file_name)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # quantize and save wav\n",
|
" # quantize and save wav\n",
|
||||||
" if QUANTIZED_WAV:\n",
|
" if QUANTIZE_BITS > 0:\n",
|
||||||
" wavq = ap.quantize(wav)\n",
|
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
|
||||||
" np.save(wavq_path, wavq)\n",
|
" np.save(wavq_path, wavq)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # save TTS mel\n",
|
" # save TTS mel\n",
|
||||||
" mel = postnet_outputs[idx]\n",
|
" mel = postnet_outputs[idx]\n",
|
||||||
" mel_length = mel_lengths[idx]\n",
|
" mel_length = data[\"mel_lengths\"][idx]\n",
|
||||||
" mel = mel[:mel_length, :].T\n",
|
" mel = mel[:mel_length, :].T\n",
|
||||||
" np.save(mel_path, mel)\n",
|
" np.save(mel_path, mel)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" metadata.append([wav_file_path, mel_path])\n",
|
" metadata.append([wav_file_path, mel_path])\n",
|
||||||
"\n",
|
|
||||||
" except Exception as e:\n",
|
" except Exception as e:\n",
|
||||||
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -224,35 +209,20 @@
|
||||||
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
|
||||||
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Close the log file\n",
|
|
||||||
"log_file.close()\n",
|
|
||||||
"\n",
|
|
||||||
"# For wavernn\n",
|
"# For wavernn\n",
|
||||||
"if not DRY_RUN:\n",
|
"if not DRY_RUN:\n",
|
||||||
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# For pwgan\n",
|
"# For pwgan\n",
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
||||||
" for data in metadata:\n",
|
" for wav_file_path, mel_path in metadata:\n",
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
|
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print mean losses\n",
|
"# Print mean losses\n",
|
||||||
"print(f\"Mean Loss: {mean_loss}\")\n",
|
"print(f\"Mean Loss: {mean_loss}\")\n",
|
||||||
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# for pwgan\n",
|
|
||||||
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
|
|
||||||
" for data in metadata:\n",
|
|
||||||
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -267,7 +237,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"idx = 1\n",
|
"idx = 1\n",
|
||||||
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
|
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -276,10 +246,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import soundfile as sf\n",
|
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
|
||||||
"wav, sr = sf.read(item_idx[idx])\n",
|
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
|
||||||
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
|
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
|
||||||
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
|
|
||||||
"mel_truth = ap.melspectrogram(wav)\n",
|
"mel_truth = ap.melspectrogram(wav)\n",
|
||||||
"print(mel_truth.shape)"
|
"print(mel_truth.shape)"
|
||||||
]
|
]
|
||||||
|
@ -291,7 +260,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# plot posnet output\n",
|
"# plot posnet output\n",
|
||||||
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
|
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
|
||||||
"plot_spectrogram(mel_postnet, ap)"
|
"plot_spectrogram(mel_postnet, ap)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -324,10 +293,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# postnet, decoder diff\n",
|
"# postnet, decoder diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff = mel_decoder - mel_postnet\n",
|
"mel_diff = mel_decoder - mel_postnet\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -339,10 +307,9 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
"mel_diff2 = mel_truth.T - mel_decoder\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
|
@ -354,21 +321,13 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# PLOT GT SPECTROGRAM diff\n",
|
"# PLOT GT SPECTROGRAM diff\n",
|
||||||
"from matplotlib import pylab as plt\n",
|
|
||||||
"mel = postnet_outputs[idx]\n",
|
"mel = postnet_outputs[idx]\n",
|
||||||
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
|
||||||
"plt.figure(figsize=(16, 10))\n",
|
"plt.figure(figsize=(16, 10))\n",
|
||||||
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
|
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
|
||||||
"plt.colorbar()\n",
|
"plt.colorbar()\n",
|
||||||
"plt.tight_layout()"
|
"plt.tight_layout()"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
Loading…
Reference in New Issue