This notebook generates mel-spectrograms from a trained TTS model, to be used for WaveRNN vocoder training.
In [ ]:
%load_ext autoreload
%autoreload 2
import os
import sys
import torch
import importlib
import numpy as np
from tqdm import tqdm as tqdm
from torch.utils.data import DataLoader
from mozilla_voice_tts.tts.datasets.TTSDataset import MyDataset
from mozilla_voice_tts.tts.layers.losses import L1LossMasked
from mozilla_voice_tts.tts.utils.audio import AudioProcessor
from mozilla_voice_tts.tts.utils.visual import plot_spectrogram
from mozilla_voice_tts.tts.utils.generic_utils import load_config, setup_model, sequence_mask
from mozilla_voice_tts.tts.utils.text.symbols import make_symbols, symbols, phonemes
%matplotlib inline

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
In [ ]:
def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split('.')[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_path = os.path.join(out_path, "wav_gl", file_name)
    return file_name, wavq_path, mel_path, wav_path
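For reference, a minimal sketch of what the helper returns (the input paths below are hypothetical):

file_name, wavq_path, mel_path, wav_path = set_filename("/data/wavs/LJ001-0001.wav", "/data/out")
# file_name == "LJ001-0001"
# wavq_path == "/data/out/quant/LJ001-0001"
# mel_path  == "/data/out/mel/LJ001-0001"
# wav_path  == "/data/out/wav_gl/LJ001-0001"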
In [ ]:
OUT_PATH = "/home/erogol/Data/LJSpeech-1.1/ljspeech-March-17-2020_01+16AM-871588c/"
DATA_PATH = "/home/erogol/Data/LJSpeech-1.1/"
DATASET = "ljspeech"
METADATA_FILE = "metadata.csv"
CONFIG_PATH = "/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/config.json"
MODEL_FILE = "/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/checkpoint_420000.pth.tar"
BATCH_SIZE = 32

QUANTIZED_WAV = False
QUANTIZE_BIT = 9
DRY_RUN = False  # if True, do not generate output files, only compute loss and visuals.

use_cuda = torch.cuda.is_available()
print(" > CUDA enabled: ", use_cuda)

C = load_config(CONFIG_PATH)
C.audio['do_trim_silence'] = False  # IMPORTANT: disable silence trimming to keep mel specs aligned with the wav files
ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)
In [ ]:
# if a custom character set is defined in the config, replace the default vocabulary
if 'characters' in C.keys():
    symbols, phonemes = make_symbols(**C.characters)

# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: multi-speaker model
model = setup_model(num_chars, num_speakers=0, c=C)
checkpoint = torch.load(MODEL_FILE)
model.load_state_dict(checkpoint['model'])
print(checkpoint['step'])
model.eval()
model.decoder.set_r(checkpoint['r'])
if use_cuda:
    model = model.cuda()
In [ ]:
preprocessor = importlib.import_module('mozilla_voice_tts.tts.datasets.preprocess')
preprocessor = getattr(preprocessor, DATASET.lower())
meta_data = preprocessor(DATA_PATH, METADATA_FILE)
dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,
                    tp=C.characters if 'characters' in C.keys() else None,
                    use_phonemes=C.use_phonemes,
                    phoneme_cache_path=C.phoneme_cache_path,
                    enable_eos_bos=C.enable_eos_bos_chars)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4,
                    collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)
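Before running the full extraction, it can help to pull a single batch and check tensor shapes. A minimal sketch, using the same tuple indices as the extraction loop below (the shape comments are assumptions based on how the batch is consumed there):

data = next(iter(loader))
text_input, text_lengths = data[0], data[1]
mel_input, mel_lengths = data[4], data[5]
print(text_input.shape)  # (batch, max_text_len)
print(mel_input.shape)   # (batch, max_mel_len, num_mels)
print(mel_lengths[:4])   # unpadded mel lengths per item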
Generate model outputs
In [ ]:
import pickle

file_idxs = []
metadata = []
losses = []
postnet_losses = []
criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)
with torch.no_grad():
    for data in tqdm(loader):
        # setup input data
        text_input = data[0]
        text_lengths = data[1]
        linear_input = data[3]
        mel_input = data[4]
        mel_lengths = data[5]
        stop_targets = data[6]
        item_idx = data[7]

        # dispatch data to GPU
        if use_cuda:
            text_input = text_input.cuda()
            text_lengths = text_lengths.cuda()
            mel_input = mel_input.cuda()
            mel_lengths = mel_lengths.cuda()

        mask = sequence_mask(text_lengths)
        mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)

        # compute loss
        loss = criterion(mel_outputs, mel_input, mel_lengths)
        loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)
        losses.append(loss.item())
        postnet_losses.append(loss_postnet.item())

        # compute mel specs from linear spec if model is Tacotron
        if C.model == "Tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())
            postnet_outputs = torch.stack(mel_specs)
        elif C.model == "Tacotron2":
            postnet_outputs = postnet_outputs.detach().cpu().numpy()
        alignments = alignments.detach().cpu().numpy()

        if not DRY_RUN:
            for idx in range(text_input.shape[0]):
                wav_file_path = item_idx[idx]
                wav = ap.load_wav(wav_file_path)
                file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)
                file_idxs.append(file_name)

                # quantize and save wav
                if QUANTIZED_WAV:
                    wavq = ap.quantize(wav)
                    np.save(wavq_path, wavq)

                # save TTS mel
                mel = postnet_outputs[idx]
                mel_length = mel_lengths[idx]
                mel = mel[:mel_length, :].T
                np.save(mel_path, mel)

                metadata.append([wav_file_path, mel_path])

# for wavernn
if not DRY_RUN:
    pickle.dump(file_idxs, open(OUT_PATH + "/dataset_ids.pkl", "wb"))

# for pwgan
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
    for data in metadata:
        f.write(f"{data[0]}|{data[1] + '.npy'}\n")

print(np.mean(losses))
print(np.mean(postnet_losses))
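To spot-check what was written, one can reload a saved mel and confirm its orientation. A minimal sketch, assuming the extraction above ran with DRY_RUN = False so file_idxs is non-empty (np.save appends the .npy extension):

mel = np.load(os.path.join(OUT_PATH, "mel", file_idxs[0] + ".npy"))
print(mel.shape)  # (num_mels, T): mels are saved transposed, features first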
In [ ]:
# for pwgan
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
    for data in metadata:
        f.write(f"{data[0]}|{data[1] + '.npy'}\n")
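Each line of metadata.txt pairs a source wav with its extracted mel file, separated by '|'. A minimal sketch of reading it back (the split mirrors the write above):

with open(os.path.join(OUT_PATH, "metadata.txt")) as f:
    pairs = [line.strip().split('|') for line in f]
print(pairs[0])  # [wav_path, mel_npy_path]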
Sanity Check
In [ ]:
idx = 1
ap.melspectrogram(ap.load_wav(item_idx[idx])).shape
In [ ]:
import soundfile as sf
wav, sr = sf.read(item_idx[idx])
mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]
mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()
mel_truth = ap.melspectrogram(wav)
print(mel_truth.shape)
In [ ]:
# plot postnet output
plot_spectrogram(mel_postnet, ap);
print(mel_postnet[:mel_lengths[idx], :].shape)
In [ ]:
# plot decoder output
plot_spectrogram(mel_decoder, ap);
print(mel_decoder.shape)
In [ ]:
# plot ground-truth spectrogram
print(mel_truth.shape)
plot_spectrogram(mel_truth.T, ap);
In [ ]:
# difference between decoder and postnet outputs
from matplotlib import pylab as plt
mel_diff = mel_decoder - mel_postnet
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff[:mel_lengths[idx], :]).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# difference between ground-truth spectrogram and decoder output
from matplotlib import pylab as plt
mel_diff2 = mel_truth.T - mel_decoder
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff2).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
In [ ]:
# difference between ground-truth spectrogram and postnet output
from matplotlib import pylab as plt
mel = postnet_outputs[idx]
mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]
plt.figure(figsize=(16, 10))
plt.imshow(abs(mel_diff2).T, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()