mirror of https://github.com/coqui-ai/TTS.git
update compute_attention_masks.py
parent 0a9767afd7
commit 7beaacc55b
TTS/bin/compute_attention_masks.py
@@ -1,17 +1,3 @@
-"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
-Sample run on LJSpeech dataset.
-
->>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
-    --model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
-    --config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
-    --dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
-    --data_path /home/erogol/Data/LJSpeech-1.1/ \
-    --batch_size 16 \
-    --use_cuda true
-
-"""
-
-
 import argparse
 import importlib
 import os
@@ -20,6 +6,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+from argparse import RawTextHelpFormatter
 from TTS.tts.datasets.TTSDataset import MyDataset
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import load_checkpoint
@@ -30,40 +17,52 @@ from TTS.utils.io import load_config
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Extract attention masks from trained Tacotron models.')
+        description='''Extract attention masks from trained Tacotron/Tacotron2 models.
+These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
+
+        '''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
+(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
+
+        '''
+Example run:
+    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
+        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
+        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
+        --dataset_metafile /root/LJSpeech-1.1/metadata.csv
+        --data_path /root/LJSpeech-1.1/
+        --batch_size 32
+        --dataset ljspeech
+        --use_cuda True
+''',
+        formatter_class=RawTextHelpFormatter
+    )
     parser.add_argument('--model_path',
                         type=str,
-                        help='Path to Tacotron or Tacotron2 model file ')
+                        required=True,
+                        help='Path to Tacotron/Tacotron2 model file')
     parser.add_argument(
         '--config_path',
         type=str,
         required=True,
-        help='Path to config file for training.',
+        help='Path to Tacotron/Tacotron2 config file.',
     )
     parser.add_argument('--dataset',
                         type=str,
                         default='',
-                        help='Dataset from TTS.tts.dataset.preprocess.')
+                        required=True,
+                        help='Target dataset processor name from TTS.tts.dataset.preprocess.')
 
     parser.add_argument(
         '--dataset_metafile',
         type=str,
         default='',
+        required=True,
         help='Dataset metafile including file paths with transcripts.')
     parser.add_argument(
         '--data_path',
         type=str,
         default='',
         help='Defines the data path. It overwrites config.json.')
-    parser.add_argument('--output_path',
-                        type=str,
-                        help='path for training outputs.',
-                        default='')
-    parser.add_argument('--output_folder',
-                        type=str,
-                        default='',
-                        help='folder name for training outputs.')
-
     parser.add_argument('--use_cuda',
                         type=bool,
                         default=False,
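Side note for reviewers (not part of the diff): the new description says each mask is written next to its wav file with a ".npy" extension and can feed a Duration Predictor. Below is a minimal sketch of how such a saved mask could be reduced to per-character durations downstream; the file path and the argmax-based reduction are illustrative assumptions, not code from this repository.

    import numpy as np

    # Hypothetical mask written by this script next to path/bla.wav.
    alignment = np.load("path/bla.npy")            # shape (mel_len, text_len)
    best_char = alignment.argmax(axis=1)           # most-attended character per mel frame
    durations = np.bincount(best_char, minlength=alignment.shape[1])
    assert durations.sum() == alignment.shape[0]   # every mel frame assigned to one character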
@@ -148,10 +147,8 @@ if __name__ == '__main__':
                 mode='nearest',
                 align_corners=None,
                 recompute_scale_factor=None).squeeze(0).transpose(0, 1)
-
             # remove paddings
             alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
-
             # set file paths
             wav_file_name = os.path.basename(item_idx)
             align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
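Reviewer sketch (assumed shapes, not from this commit): the context lines above upsample the decoder-step alignment to mel-frame resolution with nearest-neighbour interpolation, then crop away the padding. The reduction factor r, the tensor names, and the dummy sizes below are assumptions for illustration only.

    import torch
    import torch.nn.functional as F

    r = 2                                        # hypothetical Tacotron reduction factor
    text_len, dec_steps, mel_len = 37, 120, 235  # made-up example sizes
    alignment = torch.rand(text_len, dec_steps)  # (chars, decoder steps)
    alignment = F.interpolate(alignment.unsqueeze(0),     # add batch dim: (1, chars, steps)
                              scale_factor=r,
                              mode='nearest',
                              align_corners=None,
                              recompute_scale_factor=None).squeeze(0).transpose(0, 1)
    # now (dec_steps * r, chars); drop padded frames and characters as in the hunk above
    alignment = alignment[:mel_len, :text_len].cpu().numpy()
    print(alignment.shape)                       # (235, 37)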
@@ -160,7 +157,7 @@ if __name__ == '__main__':
             file_paths.append([item_idx, file_path])
             np.save(file_path, alignment)
 
-    # ourpur metafile
+    # output metafile
     metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
 
     with open(metafile, "w") as f:
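Reviewer sketch (not part of the diff): the loop above saves one alignment per wav, and the collected (wav_path, alignment_path) pairs are then written to metadata_attn_mask.txt under --data_path. A minimal reader for that metafile follows, assuming one "wav_path|alignment_path" pair per line; the "|" separator is an assumption about the write loop, which is outside this hunk.

    import numpy as np

    # Path taken from the example run above; adjust to your --data_path.
    with open("/root/LJSpeech-1.1/metadata_attn_mask.txt") as f:
        pairs = [line.strip().split("|") for line in f if line.strip()]

    for wav_path, align_path in pairs:
        attn = np.load(align_path)   # (mel_len, text_len) mask saved by this script
        print(wav_path, attn.shape)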