mirror of https://github.com/coqui-ai/TTS.git
Use absolute paths of the attention masks
This commit is contained in:
parent bc396c393f
commit 545a00fc04
@@ -148,10 +148,12 @@ Example run:
                 alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                 # set file paths
                 wav_file_name = os.path.basename(item_idx)
-                align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
+                align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"
                 file_path = item_idx.replace(wav_file_name, align_file_name)
                 # save output
-                file_paths.append([item_idx, file_path])
+                wav_file_abs_path = os.path.abspath(item_idx)
+                file_abs_path = os.path.abspath(file_path)
+                file_paths.append([wav_file_abs_path, file_abs_path])
                 np.save(file_path, alignment)

         # output metafile
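This hunk (apparently from TTS/bin/compute_attention_masks.py, the script the recipe below invokes) makes two changes: alignment files get a distinct "_attn.npy" suffix, and the metafile records absolute rather than relative paths. A minimal sketch of the resulting path derivation, using a hypothetical wav path for illustration:

import os

# Hypothetical input path, for illustration only.
item_idx = "recipes/ljspeech/LJSpeech-1.1/wavs/LJ001-0001.wav"

wav_file_name = os.path.basename(item_idx)                          # LJ001-0001.wav
align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy"  # LJ001-0001_attn.npy
file_path = item_idx.replace(wav_file_name, align_file_name)

# The substance of the change: record absolute paths so downstream
# lookups are independent of the current working directory.
wav_file_abs_path = os.path.abspath(item_idx)
file_abs_path = os.path.abspath(file_path)
print([wav_file_abs_path, file_abs_path])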
@@ -1,6 +1,7 @@
 import sys
 from collections import Counter
 from pathlib import Path
+from typing import Dict, List, Tuple

 import numpy as np

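The typing import supports the new annotations on load_meta_data below. Note that the added os.path.abspath lookups also require os, and an os import is not visible in this hunk; whether it already exists elsewhere in the module cannot be seen here. A sketch of the import block the changed code needs, with that assumption marked:

import os  # assumed: needed by the os.path.abspath() lookups added below,
           # but no os import is visible in this hunk
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np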
@@ -30,7 +31,17 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]


-def load_meta_data(datasets, eval_split=True):
+def load_meta_data(datasets: List[Dict], eval_split=True) -> Tuple[List[List], List[List]]:
+    """Parse the dataset, load the samples as a list, and load the attention alignments if provided.
+
+    Args:
+        datasets (List[Dict]): A list of dataset dictionaries or dataset configs.
+        eval_split (bool, optional): If true, create an evaluation split. If an eval split is not provided
+            explicitly, one is generated automatically. Defaults to True.
+
+    Returns:
+        Tuple[List[List], List[List]]: training and evaluation splits of the dataset.
+    """
     meta_data_train_all = []
     meta_data_eval_all = [] if eval_split else None
     for dataset in datasets:
@@ -51,15 +62,15 @@ def load_meta_data(datasets, eval_split=True):
             meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
         meta_data_train_all += meta_data_train
-        # load attention masks for duration predictor training
+        # load attention masks for the duration predictor training
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
+                attn_file = meta_data[os.path.abspath(ins[1])].strip()
                 meta_data_train_all[idx].append(attn_file)
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
+                    attn_file = meta_data[os.path.abspath(ins[1])].strip()
                     meta_data_eval_all[idx].append(attn_file)
     return meta_data_train_all, meta_data_eval_all

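After this change, the attention-mask metafile acts as a mapping keyed by absolute wav paths, so each sample's wav path (apparently ins[1]) is passed through os.path.abspath before the lookup. A usage sketch, assuming load_meta_data is importable from TTS.tts.datasets and that each sample is laid out as [text, wav_path, ...]; the paths are illustrative:

from TTS.config import BaseDatasetConfig   # import path taken from the recipe below
from TTS.tts.datasets import load_meta_data  # module location assumed

dataset_config = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    meta_file_attn_mask="/data/LJSpeech-1.1/metadata_attn_mask.txt",  # illustrative path
    path="/data/LJSpeech-1.1/",
)
train_samples, eval_samples = load_meta_data([dataset_config])
# With meta_file_attn_mask set, each sample gains its alignment file as the
# last element, resolved via the absolute wav path.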
@@ -3,8 +3,11 @@ import os
 from TTS.config import BaseAudioConfig, BaseDatasetConfig
 from TTS.trainer import Trainer, TrainingArgs, init_training
 from TTS.tts.configs import FastPitchConfig
+from TTS.utils.manage import ModelManager

 output_path = os.path.dirname(os.path.abspath(__file__))
+
+# init configs
 dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"))
 audio_config = BaseAudioConfig(
     sample_rate=22050,
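The recipe points meta_file_attn_mask at the metafile that compute_attention_masks.py writes from its file_paths list (first hunk). This diff does not show the on-disk format; the sketch below assumes one pipe-separated pair of absolute paths per line, which matches how the other LJSpeech metafiles in this codebase are delimited:

# Assumed metafile layout: "<abs_wav_path>|<abs_attn_path>" per line.
rows = [
    ["/data/LJSpeech-1.1/wavs/LJ001-0001.wav",
     "/data/LJSpeech-1.1/wavs/LJ001-0001_attn.npy"],
]
with open("metadata_attn_mask.txt", "w", encoding="utf-8") as f:
    for wav_abs_path, attn_abs_path in rows:
        f.write(f"{wav_abs_path}|{attn_abs_path}\n")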
@@ -40,6 +43,14 @@ config = FastPitchConfig(
     output_path=output_path,
     datasets=[dataset_config]
 )
+
+# compute alignments
+manager = ModelManager()
+model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+# TODO: make compute_attention python callable
+os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
+
+# train the model
 args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
 trainer = Trainer(args, config, output_path, c_logger, tb_logger)
 trainer.fit()
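Once the recipe has run, a quick sanity check can confirm that an alignment file exists under its new "_attn.npy" name and was trimmed to (mel_length, text_length) as in the first hunk; the path below is hypothetical:

import numpy as np

# Hypothetical alignment file produced by compute_attention_masks.py.
attn = np.load("/data/LJSpeech-1.1/wavs/LJ001-0001_attn.npy")
assert attn.ndim == 2  # trimmed to (mel_length, text_length) before saving
print(attn.shape)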