import argparse
import importlib
import os
from argparse import RawTextHelpFormatter

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config

if __name__ == "__main__":
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes, including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written to the same path as the input wav file with a ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile /root/LJSpeech-1.1/metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file.")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="",
        required=True,
        help="Target dataset formatter name from TTS.tts.datasets.formatters.",
    )
    parser.add_argument(
        "--dataset_metafile",
        type=str,
        default="",
        required=True,
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overrides config.json.")
    parser.add_argument(
        "--use_cuda",
        # parse common truthy strings explicitly; a plain `type=bool` would treat any
        # non-empty string (including "False") as True
        type=lambda x: str(x).lower() in ("true", "1", "yes"),
        default=False,
        help="enable/disable cuda.",
    )
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if "characters" in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()

    # data loader
    preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data; data[2] (speaker names) is not needed here
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            # teacher-forced forward pass to collect decoder alignments
            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
                text_input, text_lengths, mel_input
            )

            alignments = alignments.detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # the decoder emits r frames per step; expand the alignment back to
                # one entry per mel frame with nearest-neighbor interpolation if r > 1
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=model.decoder.r,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # save the mask next to the wav file, swapping the extension for ".npy"
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
                file_path = item_idx.replace(wav_file_name, align_file_name)
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

    # output metafile mapping each wav file to its attention mask
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
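
# ------------------------------------------------------------------------------------
# Illustrative sketch (not part of this script): one common way to consume the saved
# ".npy" attention masks downstream, e.g. deriving per-character durations for a
# duration predictor. The path "path/bla.npy" is hypothetical, and the argmax/bincount
# recipe is a widely used heuristic rather than an official TTS API:
#
#   alignment = np.load("path/bla.npy")        # shape: (mel_len, text_len)
#   most_attended = alignment.argmax(axis=1)   # most-attended input symbol per frame
#   durations = np.bincount(most_attended, minlength=alignment.shape[1])
#   assert durations.sum() == alignment.shape[0]  # durations cover every mel frame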