Fix `compute_attention_masks.py`

This commit is contained in:
Eren Gölge 2021-07-12 12:28:10 +02:00
parent 994f2be2c1
commit 0f19f8c911
1 changed file with 6 additions and 7 deletions

View File

@ -8,12 +8,12 @@ import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.config import load_config
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.utils.io import load_checkpoint
if __name__ == "__main__":
# pylint: disable=bad-option-value
@ -27,7 +27,7 @@ Example run:
CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
--dataset_metafile metadata.csv
--data_path /root/LJSpeech-1.1/
--batch_size 32
--dataset ljspeech
@ -76,8 +76,7 @@ Example run:
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: handle multi-speaker
model = setup_model(C)
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
model.eval()
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
# data loader
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
@ -127,9 +126,9 @@ Example run:
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
model_outputs = model.forward(text_input, text_lengths, mel_input)
alignments = alignments.detach()
alignments = model_outputs["alignments"].detach()
for idx, alignment in enumerate(alignments):
item_idx = item_idxs[idx]
# interpolate if r > 1