From 0f19f8c911bb367c936b949d3940cc0c23f97767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 12 Jul 2021 12:28:10 +0200 Subject: [PATCH] Fix `compute_attention_masks.py` --- TTS/bin/compute_attention_masks.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 3a5c067e..7de3989d 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -8,12 +8,12 @@ import torch from torch.utils.data import DataLoader from tqdm import tqdm +from TTS.config import load_config from TTS.tts.datasets.TTSDataset import TTSDataset from TTS.tts.models import setup_model -from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.audio import AudioProcessor -from TTS.utils.io import load_config +from TTS.utils.io import load_checkpoint if __name__ == "__main__": # pylint: disable=bad-option-value @@ -27,7 +27,7 @@ Example run: CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json - --dataset_metafile /root/LJSpeech-1.1/metadata.csv + --dataset_metafile metadata.csv --data_path /root/LJSpeech-1.1/ --batch_size 32 --dataset ljspeech @@ -76,8 +76,7 @@ Example run: num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker model = setup_model(C) - model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda) - model.eval() + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) # data loader preprocessor = importlib.import_module("TTS.tts.datasets.formatters") @@ -127,9 +126,9 @@ Example run: mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input) + model_outputs = model.forward(text_input, text_lengths, mel_input) - alignments = alignments.detach() + alignments = model_outputs["alignments"].detach() for idx, alignment in enumerate(alignments): item_idx = item_idxs[idx] # interpolate if r > 1