mirror of https://github.com/coqui-ai/TTS.git
12 KiB
12 KiB
None
<html lang="en">
<head>
</head>
</html>
This notebook generates mel-spectrograms from a trained TTS model so that they can be used as training data for a vocoder.
In [ ]:
import importlib import os import pickle import numpy as np import soundfile as sf import torch from matplotlib import pylab as plt from torch.utils.data import DataLoader from tqdm import tqdm from TTS.config import load_config from TTS.tts.configs.shared_configs import BaseDatasetConfig from TTS.tts.datasets import load_tts_samples from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.layers.losses import L1LossMasked from TTS.tts.models import setup_model from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import quantize %matplotlib inline # Configure CUDA visibility os.environ['CUDA_VISIBLE_DEVICES'] = '2'
In [ ]:
# Function to create directories and file names
def set_filename(wav_path, out_path):
    """Derive the output names for one wav file and create the target folders.

    Creates ``<out_path>/quant`` and ``<out_path>/mel`` if they do not exist.

    Args:
        wav_path: Path to the source wav file.
        out_path: Root directory under which the outputs are stored.

    Returns:
        A tuple ``(file_name, wavq_path, mel_path)``: the wav base name without
        its extension, and the extension-less target paths inside the ``quant``
        and ``mel`` folders.
    """
    # Strip only the final extension: splitting on the first dot would map
    # "a.b.wav" and "a.c.wav" to the same name and overwrite outputs.
    file_name = os.path.splitext(os.path.basename(wav_path))[0]

    quant_dir = os.path.join(out_path, "quant")
    mel_dir = os.path.join(out_path, "mel")
    os.makedirs(quant_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)

    wavq_path = os.path.join(quant_dir, file_name)
    mel_path = os.path.join(mel_dir, file_name)
    return file_name, wavq_path, mel_path
In [ ]:
# Paths and configurations
OUT_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/"
DATA_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/"
PHONEME_CACHE_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache"
DATASET = "ljspeech"
METADATA_FILE = "metadata.csv"
CONFIG_PATH = "/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json"
MODEL_FILE = "/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth"
BATCH_SIZE = 32

QUANTIZE_BITS = 0  # if non-zero, quantize wav files with the given number of bits
# If True, the per-sample mel/quantized-wav files are NOT written; only the
# losses and the visuals are computed. (The generation cell saves files under
# `if not DRY_RUN:`.)
DRY_RUN = False

# Check CUDA availability
use_cuda = torch.cuda.is_available()
print(" > CUDA enabled: ", use_cuda)

# Load the configuration
dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)
C = load_config(CONFIG_PATH)
# IMPORTANT: keep silence trimming disabled so the generated mel specs stay
# aligned with the wav files.
C.audio['do_trim_silence'] = False
ap = AudioProcessor(**C.audio)
In [ ]:
# Initialize the tokenizer (init_from_config also returns the config, which
# replaces C from here on)
tokenizer, C = TTSTokenizer.init_from_config(C)

# Load the model
# TODO: multiple speakers
model = setup_model(C)
model.load_checkpoint(C, MODEL_FILE, eval=True)

# The generation cell moves every batch to the GPU when `use_cuda` is set, so
# the model must be placed on the same device; otherwise each forward pass
# fails with a device mismatch.
if use_cuda:
    model.cuda()
In [ ]:
# Load data instances: the train and eval splits are concatenated so that every
# sample goes through the same dataset/loader.
meta_data_train, meta_data_eval = load_tts_samples(dataset_config)
meta_data = meta_data_train + meta_data_eval

dataset_kwargs = dict(
    outputs_per_step=C["r"],
    compute_linear_spec=False,
    ap=ap,
    samples=meta_data,
    tokenizer=tokenizer,
    phoneme_cache_path=PHONEME_CACHE_PATH,
)
dataset = TTSDataset(**dataset_kwargs)

# Fixed order (no shuffling) and no dropped batches: every sample is visited once.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    collate_fn=dataset.collate_fn,
)
Generate model outputs
In [ ]:
# Initialize lists for storing results
file_idxs = []
metadata = []
losses = []
postnet_losses = []
criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)

# The log file is created inside OUT_PATH, so the directory has to exist before
# the loop starts.
os.makedirs(OUT_PATH, exist_ok=True)
log_file_path = os.path.join(OUT_PATH, "log.txt")

# Compare the model name case-insensitively so both "Tacotron2" and
# "tacotron2" style config values are handled.
model_name = str(C.model).lower()

# Both context managers must be listed with a comma: `a() and b()` evaluates to
# `b()` only, which would leave gradient tracking enabled.
with torch.no_grad(), open(log_file_path, "w") as log_file:
    for data in tqdm(loader, desc="Processing"):
        try:
            # dispatch data to GPU
            if use_cuda:
                for key in ("token_id", "token_id_lengths", "mel", "mel_lengths"):
                    data[key] = data[key].cuda()

            outputs = model.forward(data["token_id"], data["token_id_lengths"], data["mel"])
            mel_outputs = outputs["decoder_outputs"]
            postnet_outputs = outputs["model_outputs"]

            # compute loss
            loss = criterion(mel_outputs, data["mel"], data["mel_lengths"])
            loss_postnet = criterion(postnet_outputs, data["mel"], data["mel_lengths"])
            losses.append(loss.item())
            postnet_losses.append(loss_postnet.item())

            # Bring the postnet output to a numpy array on the host so it can
            # be sliced and saved with np.save regardless of the device used.
            if model_name == "tacotron":
                # compute mel specs from linear spec if the model is Tacotron
                linear_specs = postnet_outputs.detach().cpu().numpy()
                postnet_outputs = np.stack(
                    [ap.out_linear_to_mel(spec.T).T for spec in linear_specs]
                ).astype(np.float32)
            else:
                postnet_outputs = postnet_outputs.detach().cpu().numpy()
            alignments = outputs["alignments"].detach().cpu().numpy()

            if not DRY_RUN:
                for idx in range(data["token_id"].shape[0]):
                    wav_file_path = data["item_idxs"][idx]
                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)
                    file_idxs.append(file_name)

                    # quantize and save wav (the wav is only needed for this step)
                    if QUANTIZE_BITS > 0:
                        wav = ap.load_wav(wav_file_path)
                        # Keyword arguments are used because the numpy_transforms
                        # helper appears to be keyword-only -- TODO confirm against
                        # the installed TTS version.
                        wavq = quantize(x=wav, quantize_bits=QUANTIZE_BITS)
                        np.save(wavq_path, wavq)

                    # save TTS mel, cut to the true length of this sample and
                    # transposed before saving
                    mel_length = int(data["mel_lengths"][idx])
                    mel = postnet_outputs[idx][:mel_length, :].T
                    np.save(mel_path, mel)

                    metadata.append([wav_file_path, mel_path])
        except Exception as e:
            # Best effort: a failing batch is recorded in the log and skipped so
            # the remaining batches are still processed.
            log_file.write(f"Error processing data: {str(e)}\n")

    # Calculate and log mean losses while the log file is still open.
    # An empty list (every batch failed / empty loader) yields NaN explicitly.
    mean_loss = np.mean(losses) if losses else float("nan")
    mean_postnet_loss = np.mean(postnet_losses) if postnet_losses else float("nan")
    log_file.write(f"Mean Loss: {mean_loss}\n")
    log_file.write(f"Mean Postnet Loss: {mean_postnet_loss}\n")

# For wavernn
if not DRY_RUN:
    with open(os.path.join(OUT_PATH, "dataset_ids.pkl"), "wb") as ids_file:
        pickle.dump(file_idxs, ids_file)

# For pwgan: one "<wav path>|<mel path>.npy" line per saved sample
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
    for saved_wav_path, saved_mel_path in metadata:
        f.write(f"{saved_wav_path}|{saved_mel_path + '.npy'}\n")

# Print mean losses
print(f"Mean Loss: {mean_loss}")
print(f"Mean Postnet Loss: {mean_postnet_loss}")
Sanity Check
In [ ]:
# Pick one item of the last batch left over from the generation loop and show
# the shape of the mel-spectrogram computed directly from its wav file.
idx = 1
sample_wav_path = data["item_idxs"][idx]
ap.melspectrogram(ap.load_wav(sample_wav_path)).shape
In [ ]:
# Collect the three spectrograms compared below for sample `idx`:
# postnet output, decoder output and the one computed from the wav file.
n_frames = data["mel_lengths"][idx]
mel_postnet = postnet_outputs[idx][:n_frames, :]
mel_decoder = mel_outputs[idx][:n_frames, :].detach().cpu().numpy()

wav, sr = sf.read(data["item_idxs"][idx])
mel_truth = ap.melspectrogram(wav)
print(mel_truth.shape)
In [ ]:
# plot postnet output
postnet_view = mel_postnet[:data["mel_lengths"][idx], :]
print(postnet_view.shape)
plot_spectrogram(mel_postnet, ap)
In [ ]:
# Decoder output: print its shape, then render it.
decoder_shape = mel_decoder.shape
print(decoder_shape)
plot_spectrogram(mel_decoder, ap)
In [ ]:
# Ground-truth (GT) spectrogram computed from the wav file; it is transposed
# before plotting.
truth_shape = mel_truth.shape
print(truth_shape)
plot_spectrogram(mel_truth.T, ap)
In [ ]:
# Absolute difference between decoder and postnet outputs
mel_diff = mel_decoder - mel_postnet
abs_diff = abs(mel_diff[:data["mel_lengths"][idx], :]).T

plt.figure(figsize=(16, 10))
plt.imshow(abs_diff, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# Absolute difference between the ground-truth spectrogram and the decoder output
mel_diff2 = mel_truth.T - mel_decoder
abs_diff2 = abs(mel_diff2).T

plt.figure(figsize=(16, 10))
plt.imshow(abs_diff2, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# Absolute difference between the ground-truth spectrogram and the postnet
# output, with the postnet output cut to the ground-truth frame count.
mel = postnet_outputs[idx]
truth_frames = mel_truth.shape[1]
mel_diff2 = mel_truth.T - mel[:truth_frames]
abs_diff2 = abs(mel_diff2).T

plt.figure(figsize=(16, 10))
plt.imshow(abs_diff2, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()