mirror of https://github.com/coqui-ai/TTS.git
update compute_attention_masks.py
This commit is contained in:
parent 0a9767afd7
commit 7beaacc55b
@@ -1,17 +1,3 @@
-"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
-Sample run on LJSpeech dataset.
-
->>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
---model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
---config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
---dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
---data_path /home/erogol/Data/LJSpeech-1.1/ \
---batch_size 16 \
---use_cuda true
-
-"""
-
-
 import argparse
 import importlib
 import os
@@ -20,6 +6,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+from argparse import RawTextHelpFormatter
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import load_checkpoint
@@ -30,40 +17,52 @@ from TTS.utils.io import load_config
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Extract attention masks from trained Tacotron models.')
+        description='''Extract attention masks from trained Tacotron/Tacotron2 models.
+These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
+
+        '''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
+(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
+
+        '''
+Example run:
+CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
+--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
+--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
+--dataset_metafile /root/LJSpeech-1.1/metadata.csv
+--data_path /root/LJSpeech-1.1/
+--batch_size 32
+--dataset ljspeech
+--use_cuda True
+''',
+        formatter_class=RawTextHelpFormatter
+    )
     parser.add_argument('--model_path',
                         type=str,
-                        help='Path to Tacotron or Tacotron2 model file ')
+                        required=True,
+                        help='Path to Tacotron/Tacotron2 model file ')
     parser.add_argument(
         '--config_path',
         type=str,
         required=True,
-        help='Path to config file for training.',
+        help='Path to Tacotron/Tacotron2 config file.',
     )
     parser.add_argument('--dataset',
                         type=str,
-                        default='',
-                        help='Dataset from TTS.tts.dataset.preprocess.')
+                        required=True,
+                        help='Target dataset processor name from TTS.tts.dataset.preprocess.')
 
     parser.add_argument(
         '--dataset_metafile',
         type=str,
-        default='',
+        required=True,
         help='Dataset metafile inclusing file paths with transcripts.')
     parser.add_argument(
         '--data_path',
         type=str,
         default='',
         help='Defines the data path. It overwrites config.json.')
-    parser.add_argument('--output_path',
-                        type=str,
-                        help='path for training outputs.',
-                        default='')
+    parser.add_argument('--output_folder',
+                        type=str,
-                        default='',
+                        help='folder name for training outputs.')
 
     parser.add_argument('--use_cuda',
                         type=bool,
                         default=False,
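The newly imported RawTextHelpFormatter is what keeps the multi-line description above readable in --help output: argparse's default formatter re-wraps text and would collapse the example command onto one line. Below is a minimal standalone sketch of that behaviour, not part of the commit; the parser, option, and text are placeholders.

import argparse
from argparse import RawTextHelpFormatter

# With the default HelpFormatter the embedded newlines below would be
# re-wrapped; RawTextHelpFormatter prints the description verbatim.
parser = argparse.ArgumentParser(
    description='Example run:\n'
                '  CUDA_VISIBLE_DEVICES="0" python compute_attention_masks.py ...\n',
    formatter_class=RawTextHelpFormatter)
parser.add_argument('--use_cuda', type=bool, default=False)

args = parser.parse_args([])   # empty argv, just to exercise the sketch
print(parser.format_help())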
@@ -148,10 +147,8 @@ if __name__ == '__main__':
                 mode='nearest',
                 align_corners=None,
                 recompute_scale_factor=None).squeeze(0).transpose(0, 1)

             # remove paddings
             alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()

             # set file paths
             wav_file_name = os.path.basename(item_idx)
             align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
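The hunk above touches the block that rescales each attention map to the mel-spectrogram length and then strips padding. A self-contained sketch of the same two steps follows; the tensor sizes and variable names are invented here, and the exact shapes inside the script may differ.

import torch
import torch.nn.functional as F

# Dummy alignment of shape (text_len, decoder_steps); sizes are made up.
text_len, decoder_steps, mel_len = 40, 120, 250
alignment = torch.rand(text_len, decoder_steps)

# Resize the decoder axis to the mel length with nearest-neighbour
# interpolation, then flip to (mel_len, text_len) as the script does.
alignment = F.interpolate(alignment.unsqueeze(0),
                          size=mel_len,
                          mode='nearest').squeeze(0).transpose(0, 1)

# Remove paddings, mirroring alignment[:mel_lengths[idx], :text_lengths[idx]].
true_mel_len, true_text_len = 230, 36   # per-item lengths, also made up
alignment = alignment[:true_mel_len, :true_text_len].cpu().numpy()
print(alignment.shape)   # (230, 36)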
@@ -160,7 +157,7 @@ if __name__ == '__main__':
             file_paths.append([item_idx, file_path])
             np.save(file_path, alignment)

-    # ourpur metafile
+    # ourput metafile
     metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

     with open(metafile, "w") as f:
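The script ends by writing data_path/metadata_attn_mask.txt, pairing each wav file with its saved attention mask. Here is a minimal sketch of consuming that file, assuming one "wav_path|attn_mask_path" pair per line; the delimiter and the example path are assumptions, since the write loop falls outside the shown hunk.

import numpy as np

# Assumed location and line format; adjust to how the script actually writes it.
metafile = "/root/LJSpeech-1.1/metadata_attn_mask.txt"

with open(metafile, "r") as f:
    pairs = [line.strip().split("|") for line in f if line.strip()]

for wav_path, attn_path in pairs[:3]:
    mask = np.load(attn_path)   # (mel_len, text_len) attention mask
    print(wav_path, mask.shape)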