coqui-tts/notebooks/ExtractTTSpectrogram.ipynb


A notebook to generate mel-spectrograms from a trained TTS model, to be used for WaveRNN (vocoder) training.

In [ ]:
%load_ext autoreload
%autoreload 2
import os
import sys
import torch
import importlib
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from mozilla_voice_tts.tts.datasets.TTSDataset import MyDataset
from mozilla_voice_tts.tts.layers.losses import L1LossMasked
from mozilla_voice_tts.tts.utils.audio import AudioProcessor
from mozilla_voice_tts.tts.utils.visual import plot_spectrogram
from mozilla_voice_tts.tts.utils.generic_utils import load_config, setup_model, sequence_mask
from mozilla_voice_tts.tts.utils.text.symbols import make_symbols, symbols, phonemes

%matplotlib inline

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
In [ ]:
def set_filename(wav_path, out_path):
    """Create the output sub-folders and return the output paths for a given wav file."""
    wav_file = os.path.basename(wav_path)
    file_name = os.path.splitext(wav_file)[0]  # robust to extra dots in the file name
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_path = os.path.join(out_path, "wav_gl", file_name)
    return file_name, wavq_path, mel_path, wav_path
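
As a quick illustration (an addition, not part of the original notebook), calling set_filename on a hypothetical wav path shows the folder layout it creates under the output path:

In [ ]:
# hypothetical example paths; any wav path and output folder behave the same way
name, wavq_p, mel_p, wav_p = set_filename("/tmp/wavs/LJ001-0001.wav", "/tmp/out")
print(name)    # LJ001-0001
print(wavq_p)  # /tmp/out/quant/LJ001-0001
print(mel_p)   # /tmp/out/mel/LJ001-0001
print(wav_p)   # /tmp/out/wav_gl/LJ001-0001
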
In [ ]:
OUT_PATH = "/home/erogol/Data/LJSpeech-1.1/ljspeech-March-17-2020_01+16AM-871588c/"
DATA_PATH = "/home/erogol/Data/LJSpeech-1.1/"
DATASET = "ljspeech"
METADATA_FILE = "metadata.csv"
CONFIG_PATH = "/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/config.json"
MODEL_FILE = "/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/checkpoint_420000.pth.tar"
BATCH_SIZE = 32

QUANTIZED_WAV = False
QUANTIZE_BIT = 9
DRY_RUN = False   # if True, only compute losses and visuals; do not write output files.

use_cuda = torch.cuda.is_available()
print(" > CUDA enabled: ", use_cuda)

C = load_config(CONFIG_PATH)
C.audio['do_trim_silence'] = False  # IMPORTANT: disable silence trimming so the mel specs stay aligned with the wav files
ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)
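
As an optional added check (not in the original notebook), confirm the audio settings and quantization mode before extracting, since a mismatch here silently corrupts the vocoder training data:

In [ ]:
# added check: settings that must match the downstream vocoder setup
print("sample_rate:", C.audio['sample_rate'])
print("num_mels:", C.audio['num_mels'])
print("quantized wavs:", QUANTIZED_WAV, "| bits:", QUANTIZE_BIT)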
In [ ]:
# if the vocabulary was passed, replace the default
if 'characters' in C.keys():
    symbols, phonemes = make_symbols(**C.characters)

# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: multi-speaker support
model = setup_model(num_chars, num_speakers=0, c=C)
checkpoint = torch.load(MODEL_FILE)
model.load_state_dict(checkpoint['model'])
print(checkpoint['step'])
model.eval()
model.decoder.set_r(checkpoint['r'])
if use_cuda:
    model = model.cuda()
In [ ]:
preprocessor = importlib.import_module('mozilla_voice_tts.tts.datasets.preprocess')
preprocessor = getattr(preprocessor, DATASET.lower())
meta_data = preprocessor(DATA_PATH, METADATA_FILE)
dataset = MyDataset(
    checkpoint['r'],
    C.text_cleaner,
    False,
    ap,
    meta_data,
    tp=C.characters if 'characters' in C.keys() else None,
    use_phonemes=C.use_phonemes,
    phoneme_cache_path=C.phoneme_cache_path,
    enable_eos_bos=C.enable_eos_bos_chars)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4,
                    collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)
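
As an optional sanity check (an addition, not in the original notebook), pull a single batch and confirm the field layout matches the indices used in the extraction loop below:

In [ ]:
# added check: field order returned by dataset.collate_fn, as consumed below
batch = next(iter(loader))
print(len(batch))       # number of fields per batch
print(batch[0].shape)   # text_input: [batch, max_text_len]
print(batch[4].shape)   # mel_input: [batch, max_mel_len, num_mels]
print(batch[5][:4])     # mel_lengths of the first items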

Generate model outputs

In [ ]:
import pickle

file_idxs = []
metadata = []
losses = []
postnet_losses = []
criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)
with torch.no_grad():
    for data in tqdm(loader):
        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[3]
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        item_idx = data[7]

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda()
            text_lengths = text_lengths.cuda()
            mel_input = mel_input.cuda()
            mel_lengths = mel_lengths.cuda()

        mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
        
        # compute loss
        loss = criterion(mel_outputs, mel_input, mel_lengths)
        loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)
        losses.append(loss.item())
        postnet_losses.append(loss_postnet.item())

        # compute mel specs from linear specs if the model is Tacotron
        # (Tacotron's postnet predicts linear spectrograms); keep everything as
        # numpy arrays so the np.save calls below work for both models
        if C.model == "Tacotron":
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            postnet_outputs = np.stack(
                [ap.out_linear_to_mel(postnet_outputs[b].T).T for b in range(postnet_outputs.shape[0])])
        elif C.model == "Tacotron2":
            postnet_outputs = postnet_outputs.detach().cpu().numpy()
        alignments = alignments.detach().cpu().numpy()

        if not DRY_RUN:
            for idx in range(text_input.shape[0]):
                wav_file_path = item_idx[idx]
                wav = ap.load_wav(wav_file_path)
                file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)
                file_idxs.append(file_name)

                # quantize and save wav
                if QUANTIZED_WAV:
                    wavq = ap.quantize(wav)
                    np.save(wavq_path, wavq)

                # save TTS mel
                mel = postnet_outputs[idx]
                mel_length = mel_lengths[idx]
                mel = mel[:mel_length, :].T
                np.save(mel_path, mel)

                metadata.append([wav_file_path, mel_path])

    # for wavernn
    if not DRY_RUN:
        with open(os.path.join(OUT_PATH, "dataset_ids.pkl"), "wb") as f:
            pickle.dump(file_idxs, f)
    
    # for pwgan
    with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
        for data in metadata:
            f.write(f"{data[0]}|{data[1]+'.npy'}\n")

    print(np.mean(losses))
    print(np.mean(postnet_losses))
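
As a final verification sketch (added here; assumes DRY_RUN is False and the loop above wrote at least one mel file), load a saved mel back and confirm its orientation:

In [ ]:
# added check: mels are saved transposed, i.e. [num_mels, mel_length]
if not DRY_RUN and len(metadata) > 0:
    saved_mel = np.load(metadata[0][1] + ".npy")
    print(saved_mel.shape)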

Sanity Check

In [ ]:
# pick a sample from the last processed batch
idx = 1
ap.melspectrogram(ap.load_wav(item_idx[idx])).shape
In [ ]:
import soundfile as sf
wav, sr = sf.read(item_idx[idx])
mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]
mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()
mel_truth = ap.melspectrogram(wav)
print(mel_truth.shape)
In [ ]:
# plot postnet output
plot_spectrogram(mel_postnet, ap);
print(mel_postnet.shape)
In [ ]:
# plot decoder output
plot_spectrogram(mel_decoder, ap);
print(mel_decoder.shape)
In [ ]:
# plot GT spectrogram
print(mel_truth.shape)
plot_spectrogram(mel_truth.T, ap);
In [ ]:
# postnet, decoder diff
from matplotlib import pylab as plt
mel_diff = mel_decoder - mel_postnet
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff[:mel_lengths[idx], :]).T, aspect="auto", origin="lower");
plt.colorbar()
plt.tight_layout()
In [ ]:
# plot GT vs. decoder output diff
from matplotlib import pylab as plt
mel_diff2 = mel_truth.T - mel_decoder
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff2).T, aspect="auto", origin="lower");
plt.colorbar()
plt.tight_layout()
In [ ]:
# plot GT vs. postnet output diff
from matplotlib import pylab as plt
mel = postnet_outputs[idx]
mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff2).T, aspect="auto", origin="lower");
plt.colorbar()
plt.tight_layout()