diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 7de3989d..fc8c6629 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -148,10 +148,12 @@ Example run: alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() # set file paths wav_file_name = os.path.basename(item_idx) - align_file_name = os.path.splitext(wav_file_name)[0] + ".npy" + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" file_path = item_idx.replace(wav_file_name, align_file_name) # save output - file_paths.append([item_idx, file_path]) + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) np.save(file_path, alignment) # ourput metafile diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index a2520751..2e315963 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -1,6 +1,7 @@ import sys from collections import Counter from pathlib import Path +from typing import Dict, List, Tuple import numpy as np @@ -30,7 +31,17 @@ def split_dataset(items): return items[:eval_split_size], items[eval_split_size:] -def load_meta_data(datasets, eval_split=True): +def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]: + """Parse the dataset, load the samples as a list and load the attention alignments if provided. + + Args: + datasets (List[Dict]): A list of dataset dictionaries or dataset configs. + eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided explicitly, + generate an eval split automatically. Defaults to True. + + Returns: + Tuple[List[List], List[List]]: training and evaluation splits of the dataset. 
+ """ meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None for dataset in datasets: @@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True): meta_data_eval, meta_data_train = split_dataset(meta_data_train) meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train - # load attention masks for duration predictor training + # load attention masks for the duration predictor training if dataset.meta_file_attn_mask: meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): - attn_file = meta_data[ins[1]].strip() + attn_file = meta_data[os.path.abspath(ins[1])].strip() meta_data_train_all[idx].append(attn_file) if meta_data_eval_all: for idx, ins in enumerate(meta_data_eval_all): - attn_file = meta_data[ins[1]].strip() + attn_file = meta_data[os.path.abspath(ins[1])].strip() meta_data_eval_all[idx].append(attn_file) return meta_data_train_all, meta_data_eval_all diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index e3bd131e..91fe4bd2 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -3,8 +3,11 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig from TTS.trainer import Trainer, TrainingArgs, init_training from TTS.tts.configs import FastPitchConfig +from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/")) audio_config = BaseAudioConfig( sample_rate=22050, @@ -40,6 +43,14 @@ config = FastPitchConfig( output_path=output_path, datasets=[dataset_config] ) + +# compute alignments +manager = ModelManager() +model_path, config_path, _ = 
manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") +# TODO: make compute_attention python callable +os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true") + +# train the model args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) trainer = Trainer(args, config, output_path, c_logger, tb_logger) trainer.fit()