#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import logging
import os
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from trainer.generic_utils import count_parameters

from TTS.config import load_config
from TTS.tts.datasets import TTSDataset, load_tts_samples
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import quantize
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger

use_cuda = torch.cuda.is_available()


def setup_loader(ap, r):
    # `c`, `meta_data` and `speaker_manager` are module-level globals set in main().
    tokenizer, _ = TTSTokenizer.init_from_config(c)
    dataset = TTSDataset(
        outputs_per_step=r,
        compute_linear_spec=False,
        samples=meta_data,
        tokenizer=tokenizer,
        ap=ap,
        batch_group_size=0,
        min_text_len=c.min_text_len,
        max_text_len=c.max_text_len,
        min_audio_len=c.min_audio_len,
        max_audio_len=c.max_audio_len,
        phoneme_cache_path=c.phoneme_cache_path,
        precompute_num_workers=0,
        use_noise_augment=False,
        speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None,
        d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
    )
    if c.use_phonemes and c.compute_input_seq_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.preprocess_samples()

    loader = DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
    return loader


def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split(".")[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path


def format_data(data):
    # setup input data
    text_input = data["token_id"]
    text_lengths = data["token_id_lengths"]
    mel_input = data["mel"]
    mel_lengths = data["mel_lengths"]
    item_idx = data["item_idxs"]
    d_vectors = data["d_vectors"]
    speaker_ids = data["speaker_ids"]
    attn_mask = data["attns"]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if d_vectors is not None:
            d_vectors = d_vectors.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)

    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )


@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    speaker_ids=None,
    d_vectors=None,
):
    if model_name == "glow_tts":
        speaker_c = None
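        # Condition on the learned speaker-ID embedding when available,
        # otherwise fall back to an external d-vector.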
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif d_vectors is not None:
            speaker_c = d_vectors
        outputs = model.inference_with_MAS(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids},
        )
        model_output = outputs["model_outputs"]
        model_output = model_output.detach().cpu().numpy()

    elif "tacotron" in model_name:
        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
        postnet_outputs = outputs["model_outputs"]
        # normalize tacotron output
        if model_name == "tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            model_output = torch.stack(mel_specs).cpu().numpy()
        elif model_name == "tacotron2":
            model_output = postnet_outputs.detach().cpu().numpy()
    return model_output


def extract_spectrograms(
    data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metadata_name="metadata.txt"
):
    model.eval()
    export_metadata = []
    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
            _,
            _,
            _,
            item_idx,
        ) = format_data(data)

        model_output = inference(
            c.model.lower(),
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            d_vectors,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantize_bits > 0:
                wavq = quantize(wav, quantize_bits)
                np.save(wavq_path, wavq)

            # save TTS mel
            mel = model_output[idx]
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    with open(os.path.join(output_path, metadata_name), "w", encoding="utf-8") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1] + '.npy'}\n")


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)

    # load data instances
    meta_data_train, meta_data_eval = load_tts_samples(
        c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size
    )

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # init speaker manager
    if c.use_speaker_embedding:
        speaker_manager = SpeakerManager(data_items=meta_data)
    elif c.use_d_vector_file:
        speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file)
    else:
        speaker_manager = None

    # setup model
    model = setup_model(c)

    # restore model
    model.load_checkpoint(c, args.checkpoint_path, eval=True)
    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # set r
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantize_bits=args.quantize_bits,
        save_audio=args.save_audio,
        debug=args.debug,
        metadata_name="metadata.txt",
    )


if __name__ == "__main__":
    setup_logger("TTS", level=logging.INFO, stream=sys.stdout, formatter=ConsoleFormatter())
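
    # Command-line interface; the flags below feed load_config() and main().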
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs.", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debugging.")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files.")
    parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero.")
    parser.add_argument("--eval", action=argparse.BooleanOptionalAction, default=True, help="Also process the eval split.")
    args = parser.parse_args()

    c = load_config(args.config_path)
    c.audio.trim_silence = False
    main(args)
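
# A minimal example invocation (script name and paths are hypothetical; the
# flags are the ones defined by the parser above):
#
#   python extract_tts_spectrograms.py \
#       --config_path config.json \
#       --checkpoint_path checkpoint.pth \
#       --output_path output_specs/ \
#       --quantize_bits 10 \
#       --save_audio
#
# Outputs land in <output_path>/{mel,quant,wav,wav_gl}, and metadata.txt maps
# each source wav to its extracted mel .npy file ("<wav_path>|<mel_path>.npy").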