coqui-tts/notebooks/ExtractTTSpectrogram.ipynb


This notebook generates mel spectrograms from a trained TTS model for use in vocoder training.
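It runs the checkpoint in teacher-forced mode over the whole dataset and saves the resulting ground-truth-aligned (GTA) mel spectrograms, plus optionally quantized wavs, which vocoders such as WaveRNN and ParallelWaveGAN can train on.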

In [ ]:
import os
import pickle

import numpy as np
import soundfile as sf
import torch
from matplotlib import pylab as plt
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.layers.losses import L1LossMasked
from TTS.tts.models import setup_model
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize

%matplotlib inline

# Configure CUDA visibility
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
In [ ]:
# Create output directories and build per-file output paths
def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = os.path.splitext(wav_file)[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    return file_name, wavq_path, mel_path
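For example, given a hypothetical input `/data/wavs/LJ001-0001.wav`, `set_filename` creates `quant/` and `mel/` under `out_path` and returns `("LJ001-0001", "<out_path>/quant/LJ001-0001", "<out_path>/mel/LJ001-0001")`; the `.npy` extension is appended by `np.save` later.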
In [ ]:
# Paths and configurations
OUT_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/"
DATA_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/"
PHONEME_CACHE_PATH = "/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache"
DATASET = "ljspeech"
METADATA_FILE = "metadata.csv"
CONFIG_PATH = "/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json"
MODEL_FILE = "/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth"
BATCH_SIZE = 32

QUANTIZE_BITS = 0  # if non-zero, quantize wav files with the given number of bits
DRY_RUN = False   # if True, skip writing output files and only compute loss and visuals

# Check CUDA availability
use_cuda = torch.cuda.is_available()
print(" > CUDA enabled: ", use_cuda)

# Load the configuration
dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)
C = load_config(CONFIG_PATH)
C.audio['do_trim_silence'] = False  # IMPORTANT: disable silence trimming to keep mel specs aligned with the wav files
ap = AudioProcessor(**C.audio)
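Disabling trimming matters because the vocoder later pairs each saved mel with the original, untrimmed wav; if silence were trimmed here, the frame-to-sample alignment between the two would be lost.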
In [ ]:
# Initialize the tokenizer
tokenizer, C = TTSTokenizer.init_from_config(C)

# Load the model
# TODO: multiple speakers
model = setup_model(C)
model.load_checkpoint(C, MODEL_FILE, eval=True)
In [ ]:
# Load data instances
meta_data_train, meta_data_eval = load_tts_samples(dataset_config)
meta_data = meta_data_train + meta_data_eval

dataset = TTSDataset(
    outputs_per_step=C["r"],
    compute_linear_spec=False,
    ap=ap,
    samples=meta_data,
    tokenizer=tokenizer,
    phoneme_cache_path=PHONEME_CACHE_PATH,
)
loader = DataLoader(
    dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False
)
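
Optionally, peek at a single batch before committing to the full pass. This is a quick sketch; it assumes `collate_fn` returns a dict with the keys consumed in the extraction loop below.

In [ ]:
# Inspect one batch: expect token_id, token_id_lengths, mel, mel_lengths, item_idxs
batch = next(iter(loader))
print({k: getattr(v, "shape", type(v).__name__) for k, v in batch.items()})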

Generate model outputs

In [ ]:
# Initialize lists for storing results
file_idxs = []
metadata = []
losses = []
postnet_losses = []
criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)

# Start processing with a progress bar
os.makedirs(OUT_PATH, exist_ok=True)  # ensure the output directory exists
log_file_path = os.path.join(OUT_PATH, "log.txt")
with torch.no_grad(), open(log_file_path, "w") as log_file:
    for data in tqdm(loader, desc="Processing"):
        try:
            # dispatch data to GPU
            if use_cuda:
                data["token_id"] = data["token_id"].cuda()
                data["token_id_lengths"] = data["token_id_lengths"].cuda()
                data["mel"] = data["mel"].cuda()
                data["mel_lengths"] = data["mel_lengths"].cuda()

            outputs = model.forward(data["token_id"], data["token_id_lengths"], data["mel"])
            mel_outputs = outputs["decoder_outputs"]
            postnet_outputs = outputs["model_outputs"]

            # compute loss
            loss = criterion(mel_outputs, data["mel"], data["mel_lengths"])
            loss_postnet = criterion(postnet_outputs, data["mel"], data["mel_lengths"])
            losses.append(loss.item())
            postnet_losses.append(loss_postnet.item())

            # compute mel specs from linear spec if the model is Tacotron
            if C.model == "Tacotron":
                # convert the linear postnet output to mel scale, keeping numpy
                # arrays so the save path below works for both model types
                postnet_outputs = postnet_outputs.data.cpu().numpy()
                mel_specs = []
                for b in range(postnet_outputs.shape[0]):
                    postnet_output = postnet_outputs[b]
                    mel_specs.append(ap.out_linear_to_mel(postnet_output.T).T)
                postnet_outputs = np.stack(mel_specs)
            elif C.model == "Tacotron2":
                postnet_outputs = postnet_outputs.detach().cpu().numpy()
            alignments = outputs["alignments"].detach().cpu().numpy()

            if not DRY_RUN:
                for idx in range(data["token_id"].shape[0]):
                    wav_file_path = data["item_idxs"][idx]
                    wav = ap.load_wav(wav_file_path)
                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)
                    file_idxs.append(file_name)

                    # quantize and save wav
                    if QUANTIZE_BITS > 0:
                        wavq = quantize(wav, QUANTIZE_BITS)
                        np.save(wavq_path, wavq)

                    # save TTS mel
                    mel = postnet_outputs[idx]
                    mel_length = data["mel_lengths"][idx].item()
                    mel = mel[:mel_length, :].T
                    np.save(mel_path, mel)

                    metadata.append([wav_file_path, mel_path])
        except Exception as e:
            log_file.write(f"Error processing data: {str(e)}\n")

    # Calculate and log mean losses
    mean_loss = np.mean(losses)
    mean_postnet_loss = np.mean(postnet_losses)
    log_file.write(f"Mean Loss: {mean_loss}\n")
    log_file.write(f"Mean Postnet Loss: {mean_postnet_loss}\n")

# For WaveRNN
if not DRY_RUN:
    with open(os.path.join(OUT_PATH, "dataset_ids.pkl"), "wb") as f:
        pickle.dump(file_idxs, f)

# For ParallelWaveGAN
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
    for wav_file_path, mel_path in metadata:
        f.write(f"{wav_file_path}|{mel_path}.npy\n")

# Print mean losses
print(f"Mean Loss: {mean_loss}")
print(f"Mean Postnet Loss: {mean_postnet_loss}")

Sanity Check

In [ ]:
idx = 1
ap.melspectrogram(ap.load_wav(data["item_idxs"][idx])).shape
In [ ]:
wav, sr = sf.read(data["item_idxs"][idx])
mel_postnet = postnet_outputs[idx][:data["mel_lengths"][idx], :]
mel_decoder = mel_outputs[idx][:data["mel_lengths"][idx], :].detach().cpu().numpy()
mel_truth = ap.melspectrogram(wav)
print(mel_truth.shape)
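Note that `ap.melspectrogram` returns `(num_mels, T)` while the model outputs are `(T, num_mels)`, hence the transposes in the comparisons below.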
In [ ]:
# plot postnet output
print(mel_postnet.shape)
plot_spectrogram(mel_postnet, ap)
In [ ]:
# plot decoder output
print(mel_decoder.shape)
plot_spectrogram(mel_decoder, ap)
In [ ]:
# plot GT spectrogram
print(mel_truth.shape)
plot_spectrogram(mel_truth.T, ap)
In [ ]:
# plot decoder vs. postnet diff
mel_diff = mel_decoder - mel_postnet
plt.figure(figsize=(16, 10))
plt.imshow(np.abs(mel_diff).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# plot GT vs. decoder diff
# guard against a small frame-count mismatch between GT and model output
T = min(mel_truth.shape[1], mel_decoder.shape[0])
mel_diff2 = mel_truth.T[:T] - mel_decoder[:T]
plt.figure(figsize=(16, 10))
plt.imshow(np.abs(mel_diff2).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# plot GT vs. postnet diff (untrimmed postnet output)
mel = postnet_outputs[idx]
T = min(mel_truth.shape[1], mel.shape[0])
mel_diff2 = mel_truth.T[:T] - mel[:T]
plt.figure(figsize=(16, 10))
plt.imshow(np.abs(mel_diff2).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()