Use absolute paths for the attention masks

Eren Gölge 2021-07-16 12:13:33 +02:00
parent bc396c393f
commit 545a00fc04
3 changed files with 30 additions and 6 deletions

TTS/bin/compute_attention_masks.py

@@ -148,10 +148,12 @@ Example run:
         alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
         # set file paths
         wav_file_name = os.path.basename(item_idx)
-        align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
+        align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
         file_path = item_idx.replace(wav_file_name, align_file_name)
         # save output
-        file_paths.append([item_idx, file_path])
+        wav_file_abs_path = os.path.abspath(item_idx)
+        file_abs_path = os.path.abspath(file_path)
+        file_paths.append([wav_file_abs_path, file_abs_path])
         np.save(file_path, alignment)
         # output metafile
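The effect of this hunk is easiest to see in isolation: the alignment file gets an unambiguous `_attn.npy` suffix, and both sides of the metafile entry are resolved to absolute paths before being written out. A minimal sketch of the same logic, using an illustrative wav path (the metafile writer itself is outside this hunk):

import os

# illustrative wav path; any path produced by the dataset loader works the same way
item_idx = "recipes/ljspeech/LJSpeech-1.1/wavs/LJ001-0001.wav"

wav_file_name = os.path.basename(item_idx)                          # "LJ001-0001.wav"
align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"  # "LJ001-0001_attn.npy"
file_path = item_idx.replace(wav_file_name, align_file_name)

# resolve both entries against the current working directory so the
# metafile no longer depends on where the script was launched from
wav_file_abs_path = os.path.abspath(item_idx)
file_abs_path = os.path.abspath(file_path)
print(wav_file_abs_path, file_abs_path)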

TTS/tts/datasets/__init__.py

@@ -1,6 +1,7 @@
 import sys
 from collections import Counter
 from pathlib import Path
+from typing import Dict, List, Tuple

 import numpy as np
@@ -30,7 +31,17 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]


-def load_meta_data(datasets, eval_split=True):
+def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
+    """Parse the dataset, load the samples as a list, and load the attention alignments if provided.
+
+    Args:
+        datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
+        eval_split (bool, optional): If True and no eval split is provided explicitly, generate one
+            automatically. Defaults to True.
+
+    Returns:
+        Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
+    """
     meta_data_train_all = []
     meta_data_eval_all = [] if eval_split else None
     for dataset in datasets:
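With the annotated signature, a call site would look roughly like the following. This is a sketch that assumes `load_meta_data` is exported from `TTS.tts.datasets` (the module this file appears to be) and reuses the `BaseDatasetConfig` pattern from the recipe further down:

from TTS.config import BaseDatasetConfig
from TTS.tts.datasets import load_meta_data

dataset_config = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    path="../LJSpeech-1.1/",
)
# returns (train_samples, eval_samples); eval_samples is None when eval_split=False
meta_data_train, meta_data_eval = load_meta_data([dataset_config], eval_split=True)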
@@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True):
            meta_data_eval, meta_data_train = split_dataset(meta_data_train)
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train
-        # load attention masks for duration predictor training
+        # load attention masks for the duration predictor training
        if dataset.meta_file_attn_mask:
            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
            for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
+                attn_file = meta_data[os.path.abspath(ins[1])].strip()
                meta_data_train_all[idx].append(attn_file)
            if meta_data_eval_all:
                for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
+                    attn_file = meta_data[os.path.abspath(ins[1])].strip()
                    meta_data_eval_all[idx].append(attn_file)
    return meta_data_train_all, meta_data_eval_all
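The abspath-keyed lookup is the core of the fix: `compute_attention_masks.py` now writes absolute wav paths into the metafile, so the wav path carried in each sample must be normalized the same way before it is used as a dictionary key, otherwise formatters that emit relative paths would raise a KeyError. A toy version, assuming a sample is laid out as [text, wav_file, ...]:

import os

# toy metafile contents: absolute wav path -> attention-mask path
meta_data = {
    os.path.abspath("wavs/LJ001-0001.wav"): "/data/LJSpeech-1.1/wavs/LJ001-0001_attn.npy\n",
}

# a sample as produced by a formatter; the wav path may be relative
sample = ["Printing, in the only sense ...", "wavs/LJ001-0001.wav"]

# normalizing with os.path.abspath makes relative and absolute wav paths
# hit the same key; strip() drops the newline read from the metafile
attn_file = meta_data[os.path.abspath(sample[1])].strip()
sample.append(attn_file)
print(sample)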

recipes/ljspeech/fast_pitch/train_fast_pitch.py

@@ -3,8 +3,11 @@ import os
 from TTS.config import BaseAudioConfig, BaseDatasetConfig
 from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.tts.configs import FastPitchConfig
+from TTS.utils.manage import ModelManager

 output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
 dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"))
 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -40,6 +43,14 @@ config = FastPitchConfig(
     output_path=output_path,
     datasets=[dataset_config]
 )
+
+# compute alignments
+manager = ModelManager()
+model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+# TODO: make compute_attention_masks callable from Python
+os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
+
+# train the model
 args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
 trainer = Trainer(args, config, output_path, c_logger, tb_logger)
 trainer.fit()
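Until the TODO above is addressed, one incremental hardening (a suggestion, not part of this commit) is to swap `os.system` for `subprocess.run`, which raises on failure instead of letting training start with missing alignments. A sketch, reusing `model_path` and `config_path` from the recipe:

import subprocess

# same invocation as the os.system call above;
# check=True aborts the recipe if alignment extraction fails
subprocess.run(
    [
        "python", "TTS/bin/compute_attention_masks.py",
        "--model_path", model_path,
        "--config_path", config_path,
        "--dataset", "ljspeech",
        "--dataset_metafile", "metadata.csv",
        "--data_path", "./recipes/ljspeech/LJSpeech-1.1/",
        "--use_cuda", "true",
    ],
    check=True,
)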