#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters
from TTS.utils.io import load_config

use_cuda = torch.cuda.is_available()


def setup_loader(ap, r, verbose=False):
    dataset = MyDataset(
        r,
        c.text_cleaner,
        compute_linear_spec=False,
        meta_data=meta_data,
        ap=ap,
        tp=c.characters if "characters" in c.keys() else None,
        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
        batch_group_size=0,
        min_seq_len=c.min_seq_len,
        max_seq_len=c.max_seq_len,
        phoneme_cache_path=c.phoneme_cache_path,
        use_phonemes=c.use_phonemes,
        phoneme_language=c.phoneme_language,
        enable_eos_bos=c.enable_eos_bos_chars,
        use_noise_augment=False,
        verbose=verbose,
        speaker_mapping=speaker_mapping
        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
        else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.sort_items()

    loader = DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
    return loader


def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split(".")[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav")
    wav_path = os.path.join(out_path, "wav", file_name + ".wav")
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4]
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        if c.use_external_speaker_embedding_file:
            speaker_embeddings = data[8]
            speaker_ids = None
        else:
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
            speaker_embeddings = None
    else:
        speaker_embeddings = None
        speaker_ids = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if speaker_embeddings is not None:
            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)

    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        speaker_embeddings,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )
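
# Note on the teacher-forced inference below (a summary, not part of the original
# script): Glow-TTS goes through `inference_with_MAS`, which aligns the input text
# to the ground-truth mels via monotonic alignment search, while the Tacotron
# models run their regular training-time forward pass with the real mel frames as
# decoder input. In both cases the predicted spectrograms stay frame-aligned with
# the source audio, which is typically what you want when the extracted mels are
# later used to fine-tune a vocoder.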
@torch.no_grad()
def inference(
    model_name,
    model,
    ap,
    text_input,
    text_lengths,
    mel_input,
    mel_lengths,
    attn_mask=None,
    speaker_ids=None,
    speaker_embeddings=None,
):
    if model_name == "glow_tts":
        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
        speaker_c = None
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif speaker_embeddings is not None:
            speaker_c = speaker_embeddings

        model_output, _, _, _, _, _, _ = model.inference_with_MAS(
            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
        )
        model_output = model_output.transpose(1, 2).detach().cpu().numpy()

    elif "tacotron" in model_name:
        if c.bidirectional_decoder or c.double_decoder_consistency:
            _, postnet_outputs, _, _, _, _ = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings,
            )
        else:
            _, postnet_outputs, _, _ = model(
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                speaker_ids=speaker_ids,
                speaker_embeddings=speaker_embeddings,
            )

        # normalize tacotron output
        if model_name == "tacotron":
            # Tacotron predicts linear spectrograms; convert them back to mel so the
            # output format matches the other models.
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            # keep the result on CPU as a numpy array so np.save() works downstream
            model_output = torch.stack(mel_specs).numpy()
        elif model_name == "tacotron2":
            model_output = postnet_outputs.detach().cpu().numpy()
    return model_output


def extract_spectrograms(
    data_loader,
    model,
    ap,
    output_path,
    quantized_wav=False,
    save_audio=False,
    debug=False,
    metadata_name="metadata.txt",
):
    model.eval()
    export_metadata = []
    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):
        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            speaker_embeddings,
            _,
            _,
            attn_mask,
            item_idx,
        ) = format_data(data)

        model_output = inference(
            c.model.lower(),
            model,
            ap,
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            attn_mask,
            speaker_ids,
            speaker_embeddings,
        )

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantized_wav:
                wavq = ap.quantize(wav)
                np.save(wavq_path, wavq)

            # save TTS mel
            mel = model_output[idx]
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    with open(os.path.join(output_path, metadata_name), "w") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1]}.npy\n")


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, symbols, phonemes, model_characters, speaker_mapping

    # Audio processor
    ap = AudioProcessor(**c.audio)
    if "characters" in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # set model characters
    model_characters = phonemes if c.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)

    # setup model
    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)

    # restore model
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    # set r (decoder reduction factor, i.e. frames predicted per decoder step)
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(
        own_loader,
        model,
        ap,
        args.output_path,
        quantized_wav=args.quantized,
        save_audio=args.save_audio,
        debug=args.debug,
        metadata_name="metadata.txt",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True)
    parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True)
    parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True)
    parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug")
    parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files")
    parser.add_argument("--quantized", action="store_true", help="Save quantized audio files")
    args = parser.parse_args()

    c = load_config(args.config_path)
    main(args)
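
# Example invocation (the script path and all file paths below are placeholders,
# not taken from the repository):
#   python /path/to/this_script.py \
#       --config_path /path/to/config.json \
#       --checkpoint_path /path/to/checkpoint.pth.tar \
#       --output_path /path/to/output_dir \
#       --save_audio --quantized
# The extracted mels are written to <output_path>/mel/ and the metadata file lists
# one "wav_path|mel_path.npy" pair per line, as produced by extract_spectrograms().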