update compute_attention_masks.py

This commit is contained in:
root 2021-01-13 10:03:57 +00:00
parent 0a9767afd7
commit 7beaacc55b
1 changed file with 27 additions and 30 deletions

View File

@@ -1,17 +1,3 @@
"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
Sample run on LJSpeech dataset.
>>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
--model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
--config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
--dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
--data_path /home/erogol/Data/LJSpeech-1.1/ \
--batch_size 16 \
--use_cuda true
"""
import argparse import argparse
import importlib import importlib
import os import os
@@ -20,6 +6,7 @@ import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
from argparse import RawTextHelpFormatter
from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint from TTS.tts.utils.io import load_checkpoint
@@ -30,40 +17,52 @@ from TTS.utils.io import load_config
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Extract attention masks from trained Tacotron models.') description='''Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
'''
Example run:
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
--config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
--dataset_metafile /root/LJSpeech-1.1/metadata.csv
--data_path /root/LJSpeech-1.1/
--batch_size 32
--dataset ljspeech
--use_cuda True
''',
formatter_class=RawTextHelpFormatter
)
parser.add_argument('--model_path', parser.add_argument('--model_path',
type=str, type=str,
help='Path to Tacotron or Tacotron2 model file ') required=True,
help='Path to Tacotron/Tacotron2 model file ')
parser.add_argument( parser.add_argument(
'--config_path', '--config_path',
type=str, type=str,
required=True, required=True,
help='Path to config file for training.', help='Path to Tacotron/Tacotron2 config file.',
) )
parser.add_argument('--dataset', parser.add_argument('--dataset',
type=str, type=str,
default='', default='',
help='Dataset from TTS.tts.dataset.preprocess.') required=True,
help='Target dataset processor name from TTS.tts.dataset.preprocess.')
parser.add_argument( parser.add_argument(
'--dataset_metafile', '--dataset_metafile',
type=str, type=str,
default='', default='',
required=True,
help='Dataset metafile inclusing file paths with transcripts.') help='Dataset metafile inclusing file paths with transcripts.')
parser.add_argument( parser.add_argument(
'--data_path', '--data_path',
type=str, type=str,
default='', default='',
help='Defines the data path. It overwrites config.json.') help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')
parser.add_argument('--use_cuda', parser.add_argument('--use_cuda',
type=bool, type=bool,
default=False, default=False,
@@ -148,10 +147,8 @@ if __name__ == '__main__':
mode='nearest', mode='nearest',
align_corners=None, align_corners=None,
recompute_scale_factor=None).squeeze(0).transpose(0, 1) recompute_scale_factor=None).squeeze(0).transpose(0, 1)
# remove paddings # remove paddings
alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy() alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
# set file paths # set file paths
wav_file_name = os.path.basename(item_idx) wav_file_name = os.path.basename(item_idx)
align_file_name = os.path.splitext(wav_file_name)[0] + '.npy' align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
@@ -160,7 +157,7 @@ if __name__ == '__main__':
file_paths.append([item_idx, file_path]) file_paths.append([item_idx, file_path])
np.save(file_path, alignment) np.save(file_path, alignment)
# ourpur metafile # ourput metafile
metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
with open(metafile, "w") as f: with open(metafile, "w") as f: