Use absolute paths of the attention masks

Eren Gölge 2021-07-16 12:13:33 +02:00
parent bc396c393f
commit 545a00fc04
3 changed files with 30 additions and 6 deletions

@@ -148,10 +148,12 @@ Example run:
         alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
         # set file paths
         wav_file_name = os.path.basename(item_idx)
-        align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
+        align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
         file_path = item_idx.replace(wav_file_name, align_file_name)
         # save output
-        file_paths.append([item_idx, file_path])
+        wav_file_abs_path = os.path.abspath(item_idx)
+        file_abs_path = os.path.abspath(file_path)
+        file_paths.append([wav_file_abs_path, file_abs_path])
         np.save(file_path, alignment)
     # output metafile
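
For context, the pairs collected in file_paths are what end up in the attention-mask metafile that the recipe below points to. A minimal sketch of that output step, assuming one pipe-separated wav_path|attn_path line per utterance (the layout and paths here are illustrative, not taken from this commit):

# illustrative [wav_abs_path, attn_abs_path] pairs, as collected above
file_paths = [
    ["/data/LJSpeech-1.1/wavs/LJ001-0001.wav", "/data/LJSpeech-1.1/wavs/LJ001-0001_attn.npy"],
]

# write one pipe-separated "wav|attn" line per utterance (assumed layout)
with open("metadata_attn_mask.txt", "w", encoding="utf-8") as f:
    for wav_abs_path, attn_abs_path in file_paths:
        f.write(f"{wav_abs_path}|{attn_abs_path}\n")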

@@ -1,6 +1,8 @@
+import os  # needed for os.path.abspath below
 import sys
 from collections import Counter
 from pathlib import Path
+from typing import Dict, List, Tuple

 import numpy as np
@@ -30,7 +31,17 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]


-def load_meta_data(datasets, eval_split=True):
+def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
+    """Parse the dataset, load the samples as a list and load the attention alignments if provided.
+
+    Args:
+        datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
+        eval_split (bool, optional): If True, create an evaluation split. If an eval split is not
+            provided explicitly, one is generated automatically. Defaults to True.
+
+    Returns:
+        Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
+    """
     meta_data_train_all = []
     meta_data_eval_all = [] if eval_split else None
     for dataset in datasets:
@@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True):
             meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
         meta_data_train_all += meta_data_train
-        # load attention masks for duration predictor training
+        # load attention masks for the duration predictor training
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
+                attn_file = meta_data[os.path.abspath(ins[1])].strip()
                 meta_data_train_all[idx].append(attn_file)
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
+                    attn_file = meta_data[os.path.abspath(ins[1])].strip()
                     meta_data_eval_all[idx].append(attn_file)
     return meta_data_train_all, meta_data_eval_all
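
Why both sides need absolute paths: meta_data is keyed by the wav paths that compute_attention_masks.py wrote into the metafile (now absolute), while each training sample carries its own wav path at index 1, often assembled with "../" segments. os.path.abspath normalizes the sample path so the dictionary lookup matches. A small self-contained sketch with illustrative paths:

import os

# mapping loaded from the attention-mask metafile: wav path -> attention file
attn_meta = {
    "/data/recipes/LJSpeech-1.1/wavs/LJ001-0001.wav": "/data/recipes/LJSpeech-1.1/wavs/LJ001-0001_attn.npy",
}

# a training sample keeps its wav path at index 1; recipes build it with
# "../" segments, e.g. os.path.join(output_path, "../LJSpeech-1.1/")
sample = ["Printing, in the only sense ...", "/data/recipes/ljspeech/../LJSpeech-1.1/wavs/LJ001-0001.wav", "ljspeech"]

# os.path.abspath() normalizes the ".." away, so the lookup key matches the
# absolute path that compute_attention_masks.py wrote into the metafile
attn_file = attn_meta[os.path.abspath(sample[1])]
print(attn_file)  # -> /data/recipes/LJSpeech-1.1/wavs/LJ001-0001_attn.npy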

@@ -3,8 +3,11 @@ import os
 from TTS.config import BaseAudioConfig, BaseDatasetConfig
 from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.tts.configs import FastPitchConfig
+from TTS.utils.manage import ModelManager

 output_path = os.path.dirname(os.path.abspath(__file__))

 # init configs
 dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"))
 audio_config = BaseAudioConfig(
     sample_rate=22050,
@@ -40,6 +43,14 @@ config = FastPitchConfig(
     output_path=output_path,
     datasets=[dataset_config]
 )

+# compute alignments
+manager = ModelManager()
+model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+# TODO: make compute_attention python callable
+os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")

 # train the model
 args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
 trainer = Trainer(args, config, output_path, c_logger, tb_logger)
 trainer.fit()
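
Until that TODO is addressed, a possible hardening (an assumption, not part of this commit) is to shell out with subprocess.run(..., check=True), so a failed alignment step stops the recipe before training starts; the arguments mirror the os.system call above:

import subprocess

# model_path and config_path come from manager.download_model(...) above;
# the flags mirror the os.system command in the recipe exactly
subprocess.run(
    [
        "python", "TTS/bin/compute_attention_masks.py",
        "--model_path", model_path,
        "--config_path", config_path,
        "--dataset", "ljspeech",
        "--dataset_metafile", "metadata.csv",
        "--data_path", "./recipes/ljspeech/LJSpeech-1.1/",
        "--use_cuda", "true",
    ],
    check=True,  # raise CalledProcessError instead of silently continuing
)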