Add YourTTS VCTK recipe (#2198)

* Add YourTTS VCTK recipe

* Fix lint

* Add compute_embeddings and resample_files functions so they can be reused (see the usage sketch below)

* Add automatic download and speaker embedding computation for the YourTTS VCTK recipe

* Add a parameter for the eval metadata file to the compute_embeddings function
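Below is a minimal usage sketch of how the two new functions can be reused from another script. It mirrors the calls made by the VCTK recipe added in this commit; the dataset location, sample rate, file extension and job count are example values, so adjust them to your own data.

import os

from TTS.bin.compute_embeddings import compute_embeddings
from TTS.bin.resample import resample_files

DATASET_PATH = "VCTK"  # example path to a downloaded dataset

# Resample the dataset audio in place to 16 kHz using 10 worker processes.
resample_files(DATASET_PATH, 16000, file_ext="flac", n_jobs=10)

# Compute speaker d-vectors with the released speaker encoder and save them next to the data.
compute_embeddings(
    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
    os.path.join(DATASET_PATH, "speakers.pth"),
    formatter_name="vctk",
    dataset_name="vctk",
    dataset_path=DATASET_PATH,
)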
Edresson Casanova 2022-12-12 12:14:25 -03:00 committed by GitHub
parent 3b8b105b0d
commit 3b1a28fa95
3 changed files with 393 additions and 126 deletions

TTS/bin/compute_embeddings.py

@@ -11,121 +11,162 @@ from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.managers import save_file
from TTS.tts.utils.speakers import SpeakerManager
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
"""
Example runs:
python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json
python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name vctk --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --metafile /path/to/vctk/metafile.csv
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--model_path",
type=str,
help="Path to model checkpoint file. It defaults to the released speaker encoder.",
default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
)
parser.add_argument(
"--config_path",
type=str,
help="Path to model config file. It defaults to the released speaker encoder config.",
default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
)
parser.add_argument(
"--config_dataset_path",
type=str,
help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.",
default=None,
)
parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth")
parser.add_argument("--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None)
parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
parser.add_argument(
"--formatter_name",
type=str,
help="Name of the formatter to use. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--dataset_name",
type=str,
help="Name of the dataset to use. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--dataset_path",
type=str,
help="Path to the dataset. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--metafile",
type=str,
help="Path to the meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
default=None,
)
args = parser.parse_args()
use_cuda = torch.cuda.is_available() and not args.disable_cuda

if args.config_dataset_path is not None:
    c_dataset = load_config(args.config_dataset_path)
    meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not args.no_eval)
else:
    c_dataset = BaseDatasetConfig()
    c_dataset.formatter = args.formatter_name
    c_dataset.dataset_name = args.dataset_name
    c_dataset.path = args.dataset_path
    c_dataset.meta_file_train = args.metafile if args.metafile else None
    meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not args.no_eval)

if meta_data_eval is None:
    samples = meta_data_train
else:
    samples = meta_data_train + meta_data_eval

encoder_manager = SpeakerManager(
    encoder_model_path=args.model_path,
    encoder_config_path=args.config_path,
    d_vectors_file_path=args.old_file,
    use_cuda=use_cuda,
)

class_name_key = encoder_manager.encoder_config.class_name_key

# compute speaker embeddings
speaker_mapping = {}
for idx, fields in enumerate(tqdm(samples)):
    class_name = fields[class_name_key]
    audio_file = fields["audio_file"]
    embedding_key = fields["audio_unique_name"]
    root_path = fields["root_path"]

    if args.old_file is not None and embedding_key in encoder_manager.clip_ids:
        # get the embedding from the old file
        embedd = encoder_manager.get_embedding_by_clip(embedding_key)
    else:
        # extract the embedding
        embedd = encoder_manager.compute_embedding_from_clip(audio_file)

    # create speaker_mapping if target dataset is defined
    speaker_mapping[embedding_key] = {}
    speaker_mapping[embedding_key]["name"] = class_name
    speaker_mapping[embedding_key]["embedding"] = embedd

if speaker_mapping:
    # save speaker_mapping if target dataset is defined
    if os.path.isdir(args.output_path):
        mapping_file_path = os.path.join(args.output_path, "speakers.pth")
    else:
        mapping_file_path = args.output_path

    if os.path.dirname(mapping_file_path) != "":
        os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)

    save_file(speaker_mapping, mapping_file_path)
    print("Speaker embeddings saved at:", mapping_file_path)


def compute_embeddings(
    model_path,
    config_path,
    output_path,
    old_spakers_file=None,
    config_dataset_path=None,
    formatter_name=None,
    dataset_name=None,
    dataset_path=None,
    meta_file_train=None,
    meta_file_val=None,
    disable_cuda=False,
    no_eval=False,
):
    use_cuda = torch.cuda.is_available() and not disable_cuda

    if config_dataset_path is not None:
        c_dataset = load_config(config_dataset_path)
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval)
    else:
        c_dataset = BaseDatasetConfig()
        c_dataset.formatter = formatter_name
        c_dataset.dataset_name = dataset_name
        c_dataset.path = dataset_path
        if meta_file_train is not None:
            c_dataset.meta_file_train = meta_file_train
        if meta_file_val is not None:
            c_dataset.meta_file_val = meta_file_val
        meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval)

    if meta_data_eval is None:
        samples = meta_data_train
    else:
        samples = meta_data_train + meta_data_eval

    encoder_manager = SpeakerManager(
        encoder_model_path=model_path,
        encoder_config_path=config_path,
        d_vectors_file_path=old_spakers_file,
        use_cuda=use_cuda,
    )

    class_name_key = encoder_manager.encoder_config.class_name_key

    # compute speaker embeddings
    speaker_mapping = {}
    for fields in tqdm(samples):
        class_name = fields[class_name_key]
        audio_file = fields["audio_file"]
        embedding_key = fields["audio_unique_name"]

        if old_spakers_file is not None and embedding_key in encoder_manager.clip_ids:
            # get the embedding from the old file
            embedd = encoder_manager.get_embedding_by_clip(embedding_key)
        else:
            # extract the embedding
            embedd = encoder_manager.compute_embedding_from_clip(audio_file)

        # create speaker_mapping if target dataset is defined
        speaker_mapping[embedding_key] = {}
        speaker_mapping[embedding_key]["name"] = class_name
        speaker_mapping[embedding_key]["embedding"] = embedd

    if speaker_mapping:
        # save speaker_mapping if target dataset is defined
        if os.path.isdir(output_path):
            mapping_file_path = os.path.join(output_path, "speakers.pth")
        else:
            mapping_file_path = output_path

        if os.path.dirname(mapping_file_path) != "":
            os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True)

        save_file(speaker_mapping, mapping_file_path)
        print("Speaker embeddings saved at:", mapping_file_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n"""
"""
Example runs:
python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json
python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--model_path",
type=str,
help="Path to model checkpoint file. It defaults to the released speaker encoder.",
default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar",
)
parser.add_argument(
"--config_path",
type=str,
help="Path to model config file. It defaults to the released speaker encoder config.",
default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json",
)
parser.add_argument(
"--config_dataset_path",
type=str,
help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.",
default=None,
)
parser.add_argument("--output_path", type=str, help="Path for output `pth` or `json` file.", default="speakers.pth")
parser.add_argument(
"--old_file", type=str, help="Previous embedding file to only compute new audios.", default=None
)
parser.add_argument("--disable_cuda", type=bool, help="Flag to disable cuda.", default=False)
parser.add_argument("--no_eval", type=bool, help="Do not compute eval?. Default False", default=False)
parser.add_argument(
"--formatter_name",
type=str,
help="Name of the formatter to use. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--dataset_name",
type=str,
help="Name of the dataset to use. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--dataset_path",
type=str,
help="Path to the dataset. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--meta_file_train",
type=str,
help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
default=None,
)
parser.add_argument(
"--meta_file_val",
type=str,
help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`",
default=None,
)
args = parser.parse_args()
compute_embeddings(
args.model_path,
args.config_path,
args.output_path,
old_spakers_file=args.old_file,
config_dataset_path=args.config_dataset_path,
formatter_name=args.formatter_name,
dataset_name=args.dataset_name,
dataset_path=args.dataset_path,
meta_file_train=args.meta_file_train,
meta_file_val=args.meta_file_val,
disable_cuda=args.disable_cuda,
no_eval=args.no_eval,
)

TTS/bin/resample.py

@@ -16,6 +16,24 @@ def resample_file(func_args):
sf.write(filename, y, sr)
def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10):
if output_dir:
print("Recursively copying the input folder...")
copy_tree(input_dir, output_dir)
input_dir = output_dir
print("Resampling the audio files...")
audio_files = glob.glob(os.path.join(input_dir, f"**/*.{file_ext}"), recursive=True)
print(f"Found {len(audio_files)} files...")
audio_files = list(zip(audio_files, len(audio_files) * [output_sr]))
with Pool(processes=n_jobs) as p:
with tqdm(total=len(audio_files)) as pbar:
for _, _ in enumerate(p.imap_unordered(resample_file, audio_files)):
pbar.update()
print("Done !")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -70,18 +88,4 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.output_dir:
print("Recursively copying the input folder...")
copy_tree(args.input_dir, args.output_dir)
args.input_dir = args.output_dir
print("Resampling the audio files...")
audio_files = glob.glob(os.path.join(args.input_dir, f"**/*.{args.file_ext}"), recursive=True)
print(f"Found {len(audio_files)} files...")
audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
with Pool(processes=args.n_jobs) as p:
with tqdm(total=len(audio_files)) as pbar:
for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)):
pbar.update()
print("Done !")
resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs)

recipes/vctk/yourtts/train_yourtts.py (new file)

@@ -0,0 +1,222 @@
import os
import torch
from trainer import Trainer, TrainerArgs
from TTS.bin.compute_embeddings import compute_embeddings
from TTS.bin.resample import resample_files
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.vits import Vits, VitsArgs, VitsAudioConfig
from TTS.utils.downloaders import download_vctk
torch.set_num_threads(24)
# pylint: disable=W0105
"""
This recipe replicates the first experiment proposed in the YourTTS paper (https://arxiv.org/abs/2112.02418).
The YourTTS model is based on the VITS model, but it uses external speaker embeddings extracted from a pre-trained speaker encoder and has small architectural changes.
In addition, YourTTS can be trained on multilingual data; however, this recipe replicates the single-language training using the VCTK dataset.
If you are interested in multilingual training, we have commented out the parameters on the VitsArgs class instance that should be enabled for multilingual training.
In addition, you will need to add the extra datasets, following the VCTK one as an example (a commented sketch follows right after this docstring).
"""
CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
# Name of the run for the Trainer
RUN_NAME = "YourTTS-EN-VCTK"
# Path where you want to save the model's outputs (configs, checkpoints and tensorboard logs)
OUT_PATH = os.path.dirname(os.path.abspath(__file__)) # "/raid/coqui/Checkpoints/original-YourTTS/"
# If you want to do transfer learning and speed up your training, you can set the path to the original YourTTS model here
RESTORE_PATH = None # "/root/.local/share/tts/tts_models--multilingual--multi-dataset--your_tts/model_file.pth"
# This parameter is useful for debugging: it skips the training epochs and just runs the evaluation and produces the test sentences
SKIP_TRAIN_EPOCH = False
# Set here the batch size to be used in training and evaluation
BATCH_SIZE = 32
# Training sampling rate and the target sampling rate for resampling the downloaded dataset (Note: if you change this, you might need to redownload the dataset !!)
# Note: if you add new datasets, please make sure that the dataset sampling rate matches this parameter; otherwise, resample your audio files
SAMPLE_RATE = 16000
# Max audio length in seconds to be used in training (every audio longer than this will be ignored)
MAX_AUDIO_LEN_IN_SECONDS = 10
### Download VCTK dataset
VCTK_DOWNLOAD_PATH = os.path.join(CURRENT_PATH, "VCTK")
# Define the number of threads used during the audio resampling
NUM_RESAMPLE_THREADS = 10
# Check if the VCTK dataset is already downloaded; if not, download it
if not os.path.exists(VCTK_DOWNLOAD_PATH):
print(">>> Downloading VCTK dataset:")
download_vctk(VCTK_DOWNLOAD_PATH)
resample_files(VCTK_DOWNLOAD_PATH, SAMPLE_RATE, file_ext="flac", n_jobs=NUM_RESAMPLE_THREADS)
# init configs
vctk_config = BaseDatasetConfig(
formatter="vctk", dataset_name="vctk", meta_file_train="", meta_file_val="", path=VCTK_DOWNLOAD_PATH, language="en"
)
# Add here all dataset configs; in our case we just want to train with the VCTK dataset, so we only need to add VCTK. Note: if you want to add new datasets, just add them here and the speaker embeddings (d-vectors) will be computed automatically for the new datasets :)
DATASETS_CONFIG_LIST = [vctk_config]
### Extract speaker embeddings
SPEAKER_ENCODER_CHECKPOINT_PATH = (
"https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar"
)
SPEAKER_ENCODER_CONFIG_PATH = "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json"
D_VECTOR_FILES = [] # List of speaker embeddings/d-vectors to be used during the training
# Iterate over all the dataset configs, checking whether the speaker embeddings have already been computed; if not, compute them
for dataset_conf in DATASETS_CONFIG_LIST:
# Check if the embeddings are already computed; if not, compute them
embeddings_file = os.path.join(dataset_conf.path, "speakers.pth")
if not os.path.isfile(embeddings_file):
print(f">>> Computing the speaker embeddings for the {dataset_conf.dataset_name} dataset")
compute_embeddings(
SPEAKER_ENCODER_CHECKPOINT_PATH,
SPEAKER_ENCODER_CONFIG_PATH,
embeddings_file,
old_spakers_file=None,
config_dataset_path=None,
formatter_name=dataset_conf.formatter,
dataset_name=dataset_conf.dataset_name,
dataset_path=dataset_conf.path,
meta_file_train=dataset_conf.meta_file_train,
meta_file_val=dataset_conf.meta_file_val,
disable_cuda=False,
no_eval=False,
)
D_VECTOR_FILES.append(embeddings_file)
# Audio config used in training.
audio_config = VitsAudioConfig(
sample_rate=SAMPLE_RATE,
hop_length=256,
win_length=1024,
fft_size=1024,
mel_fmin=0.0,
mel_fmax=None,
num_mels=80,
)
# Init VitsArgs, setting the arguments that are needed for the YourTTS model
model_args = VitsArgs(
d_vector_file=D_VECTOR_FILES,
use_d_vector_file=True,
d_vector_dim=512,
num_layers_text_encoder=10,
resblock_type_decoder="2", # On the paper, we accidentally trained the YourTTS using ResNet blocks type 2, if you like you can use the ResNet blocks type 1 like the VITS model
# Useful parameters to enable the Speaker Consistency Loss (SCL) described in the paper
# use_speaker_encoder_as_loss=True,
# speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
# speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
# Useful parameters to enable multilingual training
# use_language_embedding=True,
# embedded_language_dim=4,
)
# General training config; here you can change the batch size and other useful parameters
config = VitsConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name="YourTTS",
run_description="""
- Original YourTTS trained using VCTK dataset
""",
dashboard_logger="tensorboard",
logger_uri=None,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=5000,
save_n_checkpoints=2,
save_checkpoints=True,
target_loss="loss_1",
print_eval=False,
use_phonemes=False,
phonemizer="espeak",
phoneme_language="en",
compute_input_seq_cache=True,
add_blank=True,
text_cleaner="english_cleaners",
phoneme_cache_path=None,
precompute_num_workers=12,
start_by_longest=True,
datasets=DATASETS_CONFIG_LIST,
cudnn_benchmark=False,
max_audio_len=SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS,
mixed_precision=False,
test_sentences=[
[
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"VCTK_p277",
None,
"en",
],
[
"Be a voice, not an echo.",
"VCTK_p239",
None,
"en",
],
[
"I'm sorry Dave. I'm afraid I can't do that.",
"VCTK_p258",
None,
"en",
],
[
"This cake is great. It's so delicious and moist.",
"VCTK_p244",
None,
"en",
],
[
"Prior to November 22, 1963.",
"VCTK_p305",
None,
"en",
],
],
# Enable the weighted sampler
use_weighted_sampler=True,
# Ensures that all speakers are seen in the training batch equally no matter how many samples each speaker has
weighted_sampler_attrs={"speaker_name": 1.0},
# It sets the Speaker Consistency Loss (SCL) α to 9, as in the paper
speaker_encoder_loss_alpha=9.0,
)
# Load all the dataset samples and split training and evaluation sets
train_samples, eval_samples = load_tts_samples(
config.datasets,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# Init the model
model = Vits.init_from_config(config)
# Init the trainer and 🚀
trainer = Trainer(
TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
trainer.fit()
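As a usage note, this recipe is launched by running the script directly with Python; the file name and GPU selection below are assumptions, so adjust them to your setup.

CUDA_VISIBLE_DEVICES="0" python train_yourtts.py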