mirror of https://github.com/coqui-ai/TTS.git
Fix `compute_attention_masks.py`
This commit is contained in:
parent
994f2be2c1
commit
0f19f8c911
|
@ -8,12 +8,12 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.io import load_checkpoint
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.io import load_checkpoint
|
||||
|
||||
if __name__ == "__main__":
|
||||
# pylint: disable=bad-option-value
|
||||
|
@ -27,7 +27,7 @@ Example run:
|
|||
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
||||
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
||||
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
||||
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
|
||||
--dataset_metafile metadata.csv
|
||||
--data_path /root/LJSpeech-1.1/
|
||||
--batch_size 32
|
||||
--dataset ljspeech
|
||||
|
@ -76,8 +76,7 @@ Example run:
|
|||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||
# TODO: handle multi-speaker
|
||||
model = setup_model(C)
|
||||
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
||||
model.eval()
|
||||
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
|
||||
|
||||
# data loader
|
||||
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
||||
|
@ -127,9 +126,9 @@ Example run:
|
|||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
|
||||
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
|
||||
model_outputs = model.forward(text_input, text_lengths, mel_input)
|
||||
|
||||
alignments = alignments.detach()
|
||||
alignments = model_outputs["alignments"].detach()
|
||||
for idx, alignment in enumerate(alignments):
|
||||
item_idx = item_idxs[idx]
|
||||
# interpolate if r > 1
|
||||
|
|
Loading…
Reference in New Issue