mirror of https://github.com/coqui-ai/TTS.git
Fix `compute_attention_masks.py`
This commit is contained in:
parent
994f2be2c1
commit
0f19f8c911
|
@ -8,12 +8,12 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from TTS.config import load_config
|
||||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||||
from TTS.tts.models import setup_model
|
from TTS.tts.models import setup_model
|
||||||
from TTS.tts.utils.io import load_checkpoint
|
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_checkpoint
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# pylint: disable=bad-option-value
|
# pylint: disable=bad-option-value
|
||||||
|
@ -27,7 +27,7 @@ Example run:
|
||||||
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
|
||||||
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
|
||||||
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
|
||||||
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
|
--dataset_metafile metadata.csv
|
||||||
--data_path /root/LJSpeech-1.1/
|
--data_path /root/LJSpeech-1.1/
|
||||||
--batch_size 32
|
--batch_size 32
|
||||||
--dataset ljspeech
|
--dataset ljspeech
|
||||||
|
@ -76,8 +76,7 @@ Example run:
|
||||||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||||
# TODO: handle multi-speaker
|
# TODO: handle multi-speaker
|
||||||
model = setup_model(C)
|
model = setup_model(C)
|
||||||
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
|
model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
# data loader
|
# data loader
|
||||||
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
|
||||||
|
@ -127,9 +126,9 @@ Example run:
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
|
|
||||||
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
|
model_outputs = model.forward(text_input, text_lengths, mel_input)
|
||||||
|
|
||||||
alignments = alignments.detach()
|
alignments = model_outputs["alignments"].detach()
|
||||||
for idx, alignment in enumerate(alignments):
|
for idx, alignment in enumerate(alignments):
|
||||||
item_idx = item_idxs[idx]
|
item_idx = item_idxs[idx]
|
||||||
# interpolate if r > 1
|
# interpolate if r > 1
|
||||||
|
|
Loading…
Reference in New Issue