"""Compute attention masks from pre-trained Tacotron or Tacotron2 models. Sample run on LJSpeech dataset. >>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \ --model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \ --config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \ --dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \ --data_path /home/erogol/Data/LJSpeech-1.1/ \ --batch_size 16 \ --use_cuda true """ import argparse import glob import importlib import os import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.utils.generic_utils import sequence_mask, setup_model from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config if __name__ == '__main__': parser = argparse.ArgumentParser( description='Extract attention masks from trained Tacotron models.') parser.add_argument('--model_path', type=str, help='Path to Tacotron or Tacotron2 model file ') parser.add_argument( '--config_path', type=str, required=True, help='Path to config file for training.', ) parser.add_argument('--dataset', type=str, default='', help='Dataset from TTS.tts.dataset.preprocess.') parser.add_argument( '--dataset_metafile', type=str, default='', help='Dataset metafile inclusing file paths with transcripts.') parser.add_argument( '--data_path', type=str, default='', help='Defines the data path. It overwrites config.json.') parser.add_argument('--output_path', type=str, help='path for training outputs.', default='') parser.add_argument('--output_folder', type=str, default='', help='folder name for training outputs.') parser.add_argument('--use_cuda', type=bool, default=False, help="enable/disable cuda.") parser.add_argument( '--batch_size', default=16, type=int, help='Batch size for the model. 
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(num_chars, num_speakers=0, c=C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()

    # data loader
    preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = MyDataset(model.decoder.r,
                        C.text_cleaner,
                        compute_linear_spec=False,
                        ap=ap,
                        meta_data=meta_data,
                        tp=C.characters if 'characters' in C.keys() else None,
                        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
                        use_phonemes=C.use_phonemes,
                        phoneme_cache_path=C.phoneme_cache_path,
                        phoneme_language=C.phoneme_language,
                        enable_eos_bos=C.enable_eos_bos_chars)
    dataset.sort_items()
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=4,
                        collate_fn=dataset.collate_fn,
                        shuffle=False,
                        drop_last=False)

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]  # unused here
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]  # unused here
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
                text_input, text_lengths, mel_input)

            alignments = alignments.detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = torch.nn.functional.interpolate(
                    alignment.transpose(0, 1).unsqueeze(0),
                    scale_factor=model.decoder.r,
                    mode='nearest').squeeze(0).transpose(0, 1)
                # remove paddings
                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
                # set file paths
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
                file_path = item_idx.replace(wav_file_name, align_file_name)
                # save output
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

    # output metafile
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
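
# A minimal sketch of how downstream code might consume the outputs above; the
# metafile path is an assumption (it lands under whatever --data_path was given),
# and each line is a "<wav path>|<alignment .npy path>" pair as written by this
# script:
#
#   import numpy as np
#
#   with open("/home/erogol/Data/LJSpeech-1.1/metadata_attn_mask.txt") as f:
#       for line in f:
#           wav_path, attn_path = line.strip().split("|")
#           attn = np.load(attn_path)  # shape: [mel_len, text_len], paddings removed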