mirror of https://github.com/coqui-ai/TTS.git
Upgrade and Optimize TTS Code in extractttsspectrogram.ipynb (#3012)
This commit is contained in:
parent 155c5fc0bd
commit f133b9d2d7
@@ -13,15 +13,15 @@
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"import sys\n",
"import torch\n",
"import importlib\n",
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"import soundfile as sf\n",
"import pickle\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
@@ -33,8 +33,8 @@
"\n",
"%matplotlib inline\n",
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
"# Configure CUDA visibility\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
]
},
{
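A note on the `CUDA_VISIBLE_DEVICES` cell above: the variable must be exported before the process first touches CUDA, after which torch renumbers the remaining devices from zero. A minimal sketch of the behaviour, assuming a multi-GPU host (the device index is illustrative):

```python
import os

# Must be set before the first CUDA call in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())    # 1: only physical GPU 2 is visible
    print(torch.cuda.current_device())  # 0: it is renumbered as cuda:0
```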
@@ -43,6 +43,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Function to create directories and file names\n",
"def set_filename(wav_path, out_path):\n",
"    wav_file = os.path.basename(wav_path)\n",
"    file_name = wav_file.split('.')[0]\n",
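The hunk cuts `set_filename` off after the file stem is extracted, while the extraction loop later unpacks a four-tuple from it. A sketch of a plausible completion, assuming the output tree uses `quant`, `mel`, and `wav_gl` subdirectories (those names are an assumption, not taken from this diff):

```python
import os

def set_filename(wav_path, out_path):
    """Sketch: build per-utterance output paths; the subdirectory
    names below are assumptions for illustration."""
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split('.')[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_path = os.path.join(out_path, "wav_gl", file_name)
    return file_name, wavq_path, mel_path, wav_path
```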
@@ -61,6 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Paths and configurations\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"DATASET = \"ljspeech\"\n",
@@ -73,12 +75,15 @@
"QUANTIZE_BIT = None\n",
"DRY_RUN = False  # if True, does not generate output files; only computes loss and visuals.\n",
"\n",
"# Check CUDA availability\n",
"use_cuda = torch.cuda.is_available()\n",
"print(\" > CUDA enabled: \", use_cuda)\n",
"\n",
"# Load the configuration\n",
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False  # IMPORTANT: disable to keep the mel specs aligned with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
"print(C['r'])"
]
},
{
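Silence trimming is disabled because a trimmed waveform would no longer line up frame-for-frame with the mel spectrograms extracted here, which breaks the wav/mel pairs used for vocoder training. A rough sanity check, assuming the usual hop-length relationship (the exact off-by-one depends on the STFT padding):

```python
# Sketch: with trimming off, the mel frame count should track wav length.
wav = ap.load_wav("/path/to/some.wav")  # illustrative path
mel = ap.melspectrogram(wav)            # shape: (num_mels, frames)
expected = len(wav) // ap.hop_length + 1
assert abs(mel.shape[1] - expected) <= 1, (mel.shape[1], expected)
```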
@@ -87,14 +92,13 @@
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
"# If the vocabulary was passed, replace the default\n",
"if 'characters' in C and C['characters']:\n",
"    symbols, phonemes = make_symbols(**C.characters)\n",
"\n",
"# load the model\n",
"# Load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speaker\n",
"# TODO: multiple speakers\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
]
@@ -105,11 +109,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessor based on the dataset\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"dataset = TTSDataset(\n",
"    checkpoint[\"config\"][\"r\"],\n",
"    C,\n",
"    C.text_cleaner,\n",
"    False,\n",
"    ap,\n",
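The `loader` iterated in the extraction cell further below is built outside the hunks shown here. A minimal sketch of a compatible construction, assuming a plain `torch.utils.data.DataLoader` driven by the dataset's own `collate_fn` (batch size and worker count are illustrative):

```python
from torch.utils.data import DataLoader

BATCH_SIZE = 32  # illustrative
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=4,
    collate_fn=dataset.collate_fn,  # pads text/mel to the batch maxima
    shuffle=False,                  # keep deterministic file order
    drop_last=False,
)
```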
@@ -124,6 +129,24 @@
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Create log file\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"log_file = open(log_file_path, \"w\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
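Opening the log with a bare `open()` relies on the explicit `log_file.close()` at the end of the next cell; if the loop raises outside its inner `try`, the handle can leak. A sketch of the context-manager alternative (`run_extraction` is a hypothetical wrapper around the extraction loop below):

```python
# Sketch: scope the log file so it closes even on an unhandled error.
with open(os.path.join(OUT_PATH, "log.txt"), "w") as log_file:
    run_extraction(log_file)  # hypothetical wrapper around the extraction loop
```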
@@ -137,83 +160,85 @@
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"# Start processing with a progress bar\n",
"with torch.no_grad():\n",
"    for data in tqdm(loader):\n",
"        # setup input data\n",
"        text_input = data[0]\n",
"        text_lengths = data[1]\n",
"        linear_input = data[3]\n",
"        mel_input = data[4]\n",
"        mel_lengths = data[5]\n",
"        stop_targets = data[6]\n",
"        item_idx = data[7]\n",
"    for data in tqdm(loader, desc=\"Processing\"):\n",
"        try:\n",
"            # setup input data\n",
"            text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
"\n",
"        # dispatch data to GPU\n",
"        if use_cuda:\n",
"            text_input = text_input.cuda()\n",
"            text_lengths = text_lengths.cuda()\n",
"            mel_input = mel_input.cuda()\n",
"            mel_lengths = mel_lengths.cuda()\n",
"            # dispatch data to GPU\n",
"            if use_cuda:\n",
"                text_input = text_input.cuda()\n",
"                text_lengths = text_lengths.cuda()\n",
"                mel_input = mel_input.cuda()\n",
"                mel_lengths = mel_lengths.cuda()\n",
"\n",
"        mask = sequence_mask(text_lengths)\n",
"        mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
"        \n",
"        # compute loss\n",
"        loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
"        loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
"        losses.append(loss.item())\n",
"        postnet_losses.append(loss_postnet.item())\n",
"            mask = sequence_mask(text_lengths)\n",
"            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
"\n",
"        # compute mel specs from linear spec if model is Tacotron\n",
"        if C.model == \"Tacotron\":\n",
"            mel_specs = []\n",
"            postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
"            for b in range(postnet_outputs.shape[0]):\n",
"                postnet_output = postnet_outputs[b]\n",
"                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
"            postnet_outputs = torch.stack(mel_specs)\n",
"        elif C.model == \"Tacotron2\":\n",
"            postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
"        alignments = alignments.detach().cpu().numpy()\n",
"            # compute loss\n",
"            loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
"            loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
"            losses.append(loss.item())\n",
"            postnet_losses.append(loss_postnet.item())\n",
"\n",
"        if not DRY_RUN:\n",
"            for idx in range(text_input.shape[0]):\n",
"                wav_file_path = item_idx[idx]\n",
"                wav = ap.load_wav(wav_file_path)\n",
"                file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
"                file_idxs.append(file_name)\n",
"            # compute mel specs from linear spec if the model is Tacotron\n",
"            if C.model == \"Tacotron\":\n",
"                mel_specs = []\n",
"                postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
"                for b in range(postnet_outputs.shape[0]):\n",
"                    postnet_output = postnet_outputs[b]\n",
"                    mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
"                postnet_outputs = torch.stack(mel_specs)\n",
"            elif C.model == \"Tacotron2\":\n",
"                postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
"            alignments = alignments.detach().cpu().numpy()\n",
"\n",
"                # quantize and save wav\n",
"                if QUANTIZED_WAV:\n",
"                    wavq = ap.quantize(wav)\n",
"                    np.save(wavq_path, wavq)\n",
"            if not DRY_RUN:\n",
"                for idx in range(text_input.shape[0]):\n",
"                    wav_file_path = item_idx[idx]\n",
"                    wav = ap.load_wav(wav_file_path)\n",
"                    file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
"                    file_idxs.append(file_name)\n",
"\n",
"                # save TTS mel\n",
"                mel = postnet_outputs[idx]\n",
"                mel_length = mel_lengths[idx]\n",
"                mel = mel[:mel_length, :].T\n",
"                np.save(mel_path, mel)\n",
"                    # quantize and save wav\n",
"                    if QUANTIZED_WAV:\n",
"                        wavq = ap.quantize(wav)\n",
"                        np.save(wavq_path, wavq)\n",
"\n",
"                metadata.append([wav_file_path, mel_path])\n",
"                    # save TTS mel\n",
"                    mel = postnet_outputs[idx]\n",
"                    mel_length = mel_lengths[idx]\n",
"                    mel = mel[:mel_length, :].T\n",
"                    np.save(mel_path, mel)\n",
"\n",
"    # for wavernn\n",
"    if not DRY_RUN:\n",
"        pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\"))\n",
"    \n",
"    # for pwgan\n",
"    with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
"        for data in metadata:\n",
"            f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
"                    metadata.append([wav_file_path, mel_path])\n",
"\n",
"    print(np.mean(losses))\n",
"    print(np.mean(postnet_losses))"
"        except Exception as e:\n",
"            log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
"\n",
"    # Calculate and log mean losses\n",
"    mean_loss = np.mean(losses)\n",
"    mean_postnet_loss = np.mean(postnet_losses)\n",
"    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
"    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n",
"# Close the log file\n",
"log_file.close()\n",
"\n",
"# For wavernn\n",
"if not DRY_RUN:\n",
"    pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
"\n",
"# For pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
"    for data in metadata:\n",
"        f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
"\n",
"# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
]
},
{
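Each row of the resulting `metadata.txt` pairs a source wav path with the saved mel path, with the `.npy` suffix already appended by the writer. A quick hedged check that one artifact loads back with the expected layout:

```python
import os
import numpy as np

# Sketch: verify the first saved spectrogram against its metadata row.
with open(os.path.join(OUT_PATH, "metadata.txt")) as f:
    wav_file_path, mel_file = f.readline().strip().split("|")

mel = np.load(mel_file)  # rows already include the '.npy' suffix
print(mel.shape)         # (num_mels, frames) after the transpose in the loop
```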