From a510baa79c7a1dad6dc6f9748cdf82d97f528ccb Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 13 Jan 2020 14:53:56 +0100
Subject: [PATCH] notebook update

---
 notebooks/ExtractTTSpectrogram.ipynb | 216 ++++++++++++++-------------
 1 file changed, 115 insertions(+), 101 deletions(-)

diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb
index 2f188620..20038f78 100644
--- a/notebooks/ExtractTTSpectrogram.ipynb
+++ b/notebooks/ExtractTTSpectrogram.ipynb
@@ -32,16 +32,17 @@
 "import numpy as np\n",
 "from tqdm import tqdm as tqdm\n",
 "from torch.utils.data import DataLoader\n",
- "from TTS.models.tacotron2 import Tacotron2\n",
 "from TTS.datasets.TTSDataset import MyDataset\n",
+ "from TTS.layers.losses import L1LossMasked\n",
 "from TTS.utils.audio import AudioProcessor\n",
 "from TTS.utils.visual import plot_spectrogram\n",
- "from TTS.utils.generic_utils import load_config, setup_model\n",
- "from TTS.datasets.preprocess import ljspeech\n",
+ "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n",
+ "from TTS.utils.text.symbols import symbols, phonemes\n",
+ "\n",
 "%matplotlib inline\n",
 "\n",
 "import os\n",
- "os.environ['CUDA_VISIBLE_DEVICES']='1'"
+ "os.environ['CUDA_VISIBLE_DEVICES']='2'"
 ]
 },
 {
@@ -68,22 +69,23 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "OUT_PATH = \"/home/erogol/Data/Mozilla/wavernn/4841/\"\n",
- "DATA_PATH = \"/home/erogol/Data/Mozilla/\"\n",
- "DATASET = \"mozilla\"\n",
+ "OUT_PATH = \"/data/rw/pit/data/turkish-vocoder/\"\n",
+ "DATA_PATH = \"/data/rw/home/Turkish\"\n",
+ "DATASET = \"ljspeech\"\n",
 "METADATA_FILE = \"metadata.txt\"\n",
- "CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/config.json\"\n",
- "MODEL_FILE = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/best_model.pth.tar\"\n",
- "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
+ "CONFIG_PATH = \"/data/rw/pit/keep/turkish-January-08-2020_01+56AM-ca5e133/config.json\"\n",
+ "MODEL_FILE = \"/data/rw/pit/keep/turkish-January-08-2020_01+56AM-ca5e133/checkpoint_255000.pth.tar\"\n",
 "BATCH_SIZE = 32\n",
 "\n",
+ "QUANTIZED_WAV = False\n",
+ "QUANTIZE_BIT = 9\n",
+ "DRY_RUN = False # if True, only computes loss and visuals and does not generate output files.\n",
+ "\n",
 "use_cuda = torch.cuda.is_available()\n",
 "print(\" > CUDA enabled: \", use_cuda)\n",
 "\n",
 "C = load_config(CONFIG_PATH)\n",
- "ap = AudioProcessor(bits=9, **C.audio)\n",
- "C.prenet_dropout = False\n",
- "C.separate_stopnet = True"
+ "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
 ]
 },
 {
@@ -92,35 +94,32 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "preprocessor = importlib.import_module('datasets.preprocess')\n",
- "preprocessor = getattr(preprocessor, DATASET.lower())\n",
- "\n",
- "dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap, preprocessor, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path)\n",
- "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from TTS.utils.text.symbols import symbols, phonemes\n",
- "from TTS.utils.generic_utils import sequence_mask\n",
- "from TTS.layers.losses import L1LossMasked\n",
- "from TTS.utils.text.symbols import symbols, phonemes\n",
- "\n",
 "# load the model\n",
 "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"model = setup_model(num_chars, C)\n", + "# TODO: multiple speaker\n", + "model = setup_model(num_chars, num_speakers=0, c=C)\n", "checkpoint = torch.load(MODEL_FILE)\n", "model.load_state_dict(checkpoint['model'])\n", "print(checkpoint['step'])\n", "model.eval()\n", + "model.decoder.set_r(checkpoint['r'])\n", "if use_cuda:\n", " model = model.cuda()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n", + "preprocessor = getattr(preprocessor, DATASET.lower())\n", + "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", + "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", + "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -137,73 +136,92 @@ "import pickle\n", "\n", "file_idxs = []\n", + "metadata = []\n", "losses = []\n", "postnet_losses = []\n", "criterion = L1LossMasked()\n", - "for data in tqdm(loader):\n", - " # setup input data\n", - " text_input = data[0]\n", - " text_lengths = data[1]\n", - " linear_input = data[2]\n", - " mel_input = data[3]\n", - " mel_lengths = data[4]\n", - " stop_targets = data[5]\n", - " item_idx = data[6]\n", - " \n", - " # dispatch data to GPU\n", - " if use_cuda:\n", - " text_input = text_input.cuda()\n", - " text_lengths = text_lengths.cuda()\n", - " mel_input = mel_input.cuda()\n", - " mel_lengths = mel_lengths.cuda()\n", - "# linear_input = linear_input.cuda()\n", - " stop_targets = stop_targets.cuda()\n", - " \n", - " mask = sequence_mask(text_lengths)\n", - " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", - " \n", - " # compute mel specs from linear spec if model is Tacotron\n", - " mel_specs = []\n", - " if C.model == \"Tacotron\":\n", - " postnet_outputs = postnet_outputs.data.cpu().numpy()\n", - " for b in range(postnet_outputs.shape[0]):\n", - " postnet_output = postnet_outputs[b]\n", - " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n", - " postnet_outputs = torch.stack(mel_specs)\n", - " \n", - " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", - " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", - " losses.append(loss.item())\n", - " postnet_losses.append(loss_postnet.item())\n", + "with torch.no_grad():\n", + " for data in tqdm(loader):\n", + " # setup input data\n", + " text_input = data[0]\n", + " text_lengths = data[1]\n", + " linear_input = data[3]\n", + " mel_input = data[4]\n", + " mel_lengths = data[5]\n", + " stop_targets = data[6]\n", + " item_idx = data[7]\n", + "\n", + " # dispatch data to GPU\n", + " if use_cuda:\n", + " text_input = text_input.cuda()\n", + " text_lengths = text_lengths.cuda()\n", + " mel_input = mel_input.cuda()\n", + " mel_lengths = mel_lengths.cuda()\n", + "\n", + " mask = sequence_mask(text_lengths)\n", + " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", + " \n", + " # compute loss\n", + " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", + " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", + " losses.append(loss.item())\n", + " postnet_losses.append(loss_postnet.item())\n", + "\n", + " # compute mel 
+ " # compute mel specs from linear spec if model is Tacotron\n",
+ " if C.model == \"Tacotron\":\n",
+ " mel_specs = []\n",
+ " postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
+ " for b in range(postnet_outputs.shape[0]):\n",
+ " postnet_output = postnet_outputs[b]\n",
+ " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
+ " postnet_outputs = torch.stack(mel_specs)\n",
+ " elif C.model == \"Tacotron2\":\n",
+ " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
+ " alignments = alignments.detach().cpu().numpy()\n",
+ "\n",
+ " if not DRY_RUN:\n",
+ " for idx in range(text_input.shape[0]):\n",
+ " wav_file_path = item_idx[idx]\n",
+ " wav = ap.load_wav(wav_file_path)\n",
+ " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
+ " file_idxs.append(file_name)\n",
+ "\n",
+ " # quantize and save wav\n",
+ " if QUANTIZED_WAV:\n",
+ " wavq = ap.quantize(wav)\n",
+ " np.save(wavq_path, wavq)\n",
+ "\n",
+ " # save TTS mel\n",
+ " mel = postnet_outputs[idx]\n",
+ " mel_length = mel_lengths[idx]\n",
+ " mel = mel[:mel_length, :].T\n",
+ " np.save(mel_path, mel)\n",
+ "\n",
+ " metadata.append([wav_file_path, mel_path])\n",
+ "\n",
+ " # for wavernn\n",
 " if not DRY_RUN:\n",
- " for idx in range(text_input.shape[0]):\n",
- " wav_file_path = item_idx[idx]\n",
- " wav = ap.load_wav(wav_file_path)\n",
- " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
- " file_idxs.append(file_name)\n",
- "\n",
- "# # quantize and save wav\n",
- "# wavq = ap.quantize(wav)\n",
- "# np.save(wavq_path, wavq)\n",
- "\n",
- " # save TTS mel\n",
- " mel = postnet_outputs[idx]\n",
- " mel = mel.data.cpu().numpy()\n",
- " mel_length = mel_lengths[idx]\n",
- " mel = mel[:mel_length, :].T\n",
- " np.save(mel_path, mel)\n",
- "\n",
- " # save GL voice\n",
- " # wav_gen = ap.inv_mel_spectrogram(mel.T) # mel to wav\n",
- " # wav_gen = ap.quantize(wav_gen)\n",
- " # np.save(wav_path, wav_gen)\n",
- "\n",
- "if not DRY_RUN:\n",
- " pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
+ " pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
 " \n",
+ " # for pwgan\n",
+ " with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
+ " for data in metadata:\n",
+ " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
 "\n",
- "print(np.mean(losses))\n",
- "print(np.mean(postnet_losses))"
+ " print(np.mean(losses))\n",
+ " print(np.mean(postnet_losses))"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
 "# for pwgan\n",
 "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
 " for data in metadata:\n",
 " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -219,8 +237,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
+ "# plot postnet output\n",
 "idx = 1\n",
- "mel_example = postnet_outputs[idx].data.cpu().numpy()\n",
+ "mel_example = postnet_outputs[idx]\n",
 "plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
 "print(mel_example[:mel_lengths[1], :].shape)"
 ]
 },
@@ -231,6 +250,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+ "# plot decoder output\n",
 "mel_example = mel_outputs[idx].data.cpu().numpy()\n",
 "plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
 "print(mel_example[:mel_lengths[1], :].shape)"
 ]
 },
@@ -242,6 +262,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+ "# plot GT spectrogram\n",
 "wav = ap.load_wav(item_idx[idx])\n",
 "melt = ap.melspectrogram(wav)\n",
 "print(melt.shape)\n",
@@ -278,13 +299,6 @@
 "plt.colorbar()\n",
 "plt.tight_layout()"
 ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
 }
 ],
 "metadata": {
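
Reviewer note: the extraction loop above calls set_filename(wav_file_path, OUT_PATH), but the helper's definition sits in a notebook cell outside this diff's context. For review purposes, here is a minimal sketch of what such a helper might look like, assuming per-type subdirectories under OUT_PATH; the directory names ("quant", "mel", "wav_gl") and the exact layout are assumptions, not the notebook's actual code.

import os

def set_filename(wav_file, out_path):
    # Hypothetical sketch, not the notebook's real definition: derive a
    # per-utterance base name from the source wav and build one output
    # path per artifact type (quantized wav, TTS mel, Griffin-Lim wav).
    file_name = os.path.splitext(os.path.basename(wav_file))[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_path = os.path.join(out_path, "wav_gl", file_name)
    return file_name, wavq_path, mel_path, wav_path

Note that np.save(mel_path, mel) appends ".npy" to a path that lacks the extension, which is why the pwgan metadata writer records data[1]+'.npy' rather than data[1].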