mirror of https://github.com/coqui-ai/TTS.git
Use absolute paths of the attention masks
This commit is contained in:
parent
bc396c393f
commit
545a00fc04
|
@ -148,10 +148,12 @@ Example run:
|
|||
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
|
||||
# set file paths
|
||||
wav_file_name = os.path.basename(item_idx)
|
||||
align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
|
||||
align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
|
||||
file_path = item_idx.replace(wav_file_name, align_file_name)
|
||||
# save output
|
||||
file_paths.append([item_idx, file_path])
|
||||
wav_file_abs_path = os.path.abspath(item_idx)
|
||||
file_abs_path = os.path.abspath(file_path)
|
||||
file_paths.append([wav_file_abs_path, file_abs_path])
|
||||
np.save(file_path, alignment)
|
||||
|
||||
# output metafile
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -30,7 +31,17 @@ def split_dataset(items):
|
|||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
||||
|
||||
def load_meta_data(datasets, eval_split=True):
|
||||
def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
|
||||
"""Parse the dataset, load the samples as a list and load the attention alignments if provided.
|
||||
|
||||
Args:
|
||||
datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
|
||||
eval_split (bool, optional): If true, create an evaluation split. If an eval split provided explicitly, generate
|
||||
an eval split automatically. Defaults to True.
|
||||
|
||||
Returns:
|
||||
Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
|
||||
"""
|
||||
meta_data_train_all = []
|
||||
meta_data_eval_all = [] if eval_split else None
|
||||
for dataset in datasets:
|
||||
|
@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True):
|
|||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
meta_data_train_all += meta_data_train
|
||||
# load attention masks for duration predictor training
|
||||
# load attention masks for the duration predictor training
|
||||
if dataset.meta_file_attn_mask:
|
||||
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||
for idx, ins in enumerate(meta_data_train_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
attn_file = meta_data[os.path.abspath(ins[1])].strip()
|
||||
meta_data_train_all[idx].append(attn_file)
|
||||
if meta_data_eval_all:
|
||||
for idx, ins in enumerate(meta_data_eval_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
attn_file = meta_data[os.path.abspath(ins[1])].strip()
|
||||
meta_data_eval_all[idx].append(attn_file)
|
||||
return meta_data_train_all, meta_data_eval_all
|
||||
|
||||
|
|
|
@ -3,8 +3,11 @@ import os
|
|||
from TTS.config import BaseAudioConfig, BaseDatasetConfig
|
||||
from TTS.trainer import Trainer, TrainingArgs, init_training
|
||||
from TTS.tts.configs import FastPitchConfig
|
||||
from TTS.utils.manage import ModelManager
|
||||
|
||||
output_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# init configs
|
||||
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"))
|
||||
audio_config = BaseAudioConfig(
|
||||
sample_rate=22050,
|
||||
|
@ -40,6 +43,14 @@ config = FastPitchConfig(
|
|||
output_path=output_path,
|
||||
datasets=[dataset_config]
|
||||
)
|
||||
|
||||
# compute alignments
|
||||
manager = ModelManager()
|
||||
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
|
||||
# TODO: make compute_attention python callable
|
||||
os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
|
||||
|
||||
# train the model
|
||||
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
|
||||
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
|
||||
trainer.fit()
|
||||
|
|
Loading…
Reference in New Issue