mirror of https://github.com/coqui-ai/TTS.git
Add syntacc training recipe
commit a45dfd6266 (parent 2bdc7a5675)
@@ -6,8 +6,7 @@ from torch.nn import functional as F

 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2
-# import sys
-# sys.setrecursionlimit(9999999)
 class AdaptiveWeightConv(nn.Module):
     def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, alpha=1, dropout=0., num_classes=None, **kwargs):
         super(AdaptiveWeightConv, self).__init__()
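Only the constructor signature of AdaptiveWeightConv is visible in this hunk. For orientation, the sketch below shows one common way such a per-class adaptive-weight convolution can be built: a shared kernel plus LoRA-style low-rank deltas selected by a class id. The class name and internals are assumptions for illustration, not the implementation in this commit.

import torch
from torch import nn
from torch.nn import functional as F

class AdaptiveWeightConvSketch(nn.Module):
    """Illustrative only: a shared Conv1d plus per-class low-rank weight deltas."""

    def __init__(self, in_channels, out_channels, kernel_size, r=4, alpha=1.0, num_classes=3):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.scaling = alpha / r
        # One low-rank factor pair (A, B) per class; only these are class specific.
        self.lora_a = nn.Parameter(torch.randn(num_classes, r, in_channels * kernel_size) * 0.02)
        self.lora_b = nn.Parameter(torch.zeros(num_classes, out_channels, r))

    def forward(self, x, class_id):
        # Compose the class-specific kernel: shared weight + scaled low-rank delta.
        delta = (self.lora_b[class_id] @ self.lora_a[class_id]).view_as(self.conv.weight)
        weight = self.conv.weight + self.scaling * delta
        return F.conv1d(x, weight, self.conv.bias, padding=self.conv.padding[0])

x = torch.randn(2, 8, 50)                             # [B, C_in, T]
y = AdaptiveWeightConvSketch(8, 16, 3)(x, class_id=1)
print(y.shape)                                        # torch.Size([2, 16, 50])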
@@ -558,7 +557,7 @@ class TextEncoder(nn.Module):
         super().__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
+        self.num_adaptive_weight_classes = num_adaptive_weight_classes
         self.emb = nn.Embedding(n_vocab, hidden_channels)

         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
@@ -582,12 +581,7 @@ class TextEncoder(nn.Module):

         self.proj = Conv1d(hidden_channels, out_channels * 2, 1, r=1 if num_adaptive_weight_classes else 0, num_classes=num_adaptive_weight_classes)

-    def forward(self, x, x_lengths, lang_emb=None, class_id=None):
-        """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_length: :math:`[B]`
-        """
+    def forward_mini_batch(self, x, x_lengths, lang_emb=None, class_id=None):
         assert x.shape[0] == x_lengths.shape[0]
         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
@@ -604,6 +598,41 @@ class TextEncoder(nn.Module):
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return x, m, logs, x_mask

+    def forward(self, x, x_lengths, lang_emb=None, class_id=None):
+        """
+        Shapes:
+            - x: :math:`[B, T]`
+            - x_length: :math:`[B]`
+        """
+        batch_size = x.size(0)
+        if self.num_adaptive_weight_classes and batch_size > 1:
+            num_utter_per_class = int(batch_size/self.num_adaptive_weight_classes)
+            # mini batch inference for each class
+            outs_x = []
+            outs_m = []
+            outs_logs = []
+            outs_x_mask = []
+
+            start = 0
+            for i in range(self.num_adaptive_weight_classes):
+                start = num_utter_per_class * i
+                end = start + num_utter_per_class
+                class_id_item = class_id[start:end][0]
+                x_out, m_out, logs_out, x_mask_out = self.forward_mini_batch(x[start:end], x_lengths[start:end], lang_emb=lang_emb[start:end] if lang_emb else None, class_id=class_id_item)
+                outs_x.append(x_out)
+                outs_m.append(m_out)
+                outs_logs.append(logs_out)
+                outs_x_mask.append(x_mask_out)
+
+            x = torch.stack(outs_x, dim=0).view(batch_size, *x_out.shape[1:])
+            m = torch.stack(outs_m, dim=0).view(batch_size, *m_out.shape[1:])
+            logs = torch.stack(outs_logs, dim=0).view(batch_size, *logs_out.shape[1:])
+            x_mask = torch.stack(outs_x_mask, dim=0).view(batch_size, *x_mask_out.shape[1:])
+            return x, m, logs, x_mask
+        else:
+            return self.forward_mini_batch(x, x_lengths, lang_emb=lang_emb, class_id=class_id)


 if __name__ == '__main__':
     txt_enc = TextEncoder(
         n_vocab=100,
@@ -642,7 +671,7 @@ if __name__ == '__main__':
         kernel_size=3,
         dropout_p=0.0,
         language_emb_dim=None,
-        num_adaptive_weight_classes=5,
+        num_adaptive_weight_classes=None,
     )

     out = txt_enc(
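The new forward() above assumes that the incoming batch is already grouped by class, with the same number of utterances per class, so that every slice x[start:end] contains a single class. A toy sketch of that layout check, under exactly those assumptions:

import torch

num_classes, utts_per_class, T = 3, 2, 7
class_id = torch.tensor([0, 0, 1, 1, 2, 2])                  # class-contiguous batch
x = torch.randint(0, 100, (num_classes * utts_per_class, T))

for i in range(num_classes):
    start, end = i * utts_per_class, (i + 1) * utts_per_class
    # every per-class slice holds exactly one class, as forward() expects
    assert class_id[start:end].unique().numel() == 1
    # the real code then calls forward_mini_batch(x[start:end], ..., class_id=class_id[start:end][0])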
@@ -35,7 +35,7 @@ from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _chara
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.io import load_fsspec
-from TTS.utils.samplers import BucketBatchSampler
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -259,6 +259,7 @@ class VitsDataset(TTSDataset):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
         self.model_args = model_args
+        self.num_classes = None

     def __getitem__(self, idx):
         item = self.samples[idx]
@@ -317,6 +318,13 @@ class VitsDataset(TTSDataset):
         """
         # convert list of dicts to dict of lists
         B = len(batch)
+        # group the samples of each class in the batch. The perfect sampler produces [3,2,1,3,2,1]; we need [3,3,2,2,1,1]
+        if self.model_args.use_perfect_class_batch_sampler:
+            new_batch = []
+            for i in range(self.num_classes):
+                new_batch.extend(batch[i:B:self.num_classes])
+            batch = new_batch
+
         batch = {k: [dic[k] for dic in batch] for k in batch[0]}

         _, ids_sorted_decreasing = torch.sort(
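The reordering added to the collate function can be verified with a toy list: a perfect class batch sampler emits the classes interleaved, and the strided slices regroup them so each class occupies a contiguous block, which is the layout TextEncoder.forward() expects.

batch = [3, 2, 1, 3, 2, 1]        # class labels as the sampler interleaves them
B, num_classes = len(batch), 3
new_batch = []
for i in range(num_classes):
    new_batch.extend(batch[i:B:num_classes])
print(new_batch)                  # [3, 3, 2, 2, 1, 1]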
@@ -546,6 +554,9 @@ class VitsArgs(Coqpit):
     out_channels: int = 513
     spec_segment_size: int = 32
     hidden_channels: int = 192
+    use_adaptive_weight_text_encoder: bool = False
+    use_perfect_class_batch_sampler: bool = False
+    perfect_class_batch_sampler_key: str = ""
     hidden_channels_ffn_text_encoder: int = 768
     num_heads_text_encoder: int = 2
     num_layers_text_encoder: int = 6
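For reference, the three new VitsArgs fields are meant to be enabled together. The snippet below mirrors how the recipe at the end of this commit sets them; the values are copied from that recipe, not additional defaults.

from TTS.tts.models.vits import VitsArgs

model_args = VitsArgs(
    use_adaptive_weight_text_encoder=True,        # per-language adaptive weights in the text encoder
    use_perfect_class_batch_sampler=True,         # every batch holds the same number of samples per class
    perfect_class_batch_sampler_key="language",   # sample field used as the class label
    use_language_embedding=False,                 # the adaptive weights replace the language embedding
)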
@@ -660,7 +671,8 @@ class Vits(BaseTTS):
            self.args.num_layers_text_encoder,
            self.args.kernel_size_text_encoder,
            self.args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
+            num_adaptive_weight_classes=self.num_languages if self.args.use_adaptive_weight_text_encoder else None,
        )

        self.posterior_encoder = PosteriorEncoder(
@@ -690,7 +702,7 @@ class Vits(BaseTTS):
                self.args.dropout_p_duration_predictor,
                4,
                cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
-                language_emb_dim=self.embedded_language_dim,
+                language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
            )
        else:
            self.duration_predictor = DurationPredictor(
@@ -699,7 +711,7 @@ class Vits(BaseTTS):
                3,
                self.args.dropout_p_duration_predictor,
                cond_channels=self.embedded_speaker_dim,
-                language_emb_dim=self.embedded_language_dim,
+                language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
            )

        self.waveform_decoder = HifiganGenerator(
@@ -794,9 +806,11 @@ class Vits(BaseTTS):
        if self.args.language_ids_file is not None:
            self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

-        if self.args.use_language_embedding and self.language_manager:
-            print(" > initialization of language-embedding layers.")
+        if self.language_manager:
            self.num_languages = self.language_manager.num_languages
+        self.embedded_language_dim = 0
+        if self.args.use_language_embedding:
+            print(" > initialization of language-embedding layers.")
            self.embedded_language_dim = self.args.embedded_language_dim
            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
            torch.nn.init.xavier_uniform_(self.emb_l.weight)
@@ -1016,7 +1030,7 @@ class Vits(BaseTTS):
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, class_id=lid)

        # posterior encoder
        z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1122,7 +1136,7 @@ class Vits(BaseTTS):
        if self.args.use_language_embedding and lid is not None:
            lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, class_id=lid)

        if durations is None:
            if self.args.use_sdp:
@@ -1413,7 +1427,7 @@ class Vits(BaseTTS):
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+        if hasattr(self, "language_manager") and (config.use_language_embedding or config.use_adaptive_weight_text_encoder) and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
@@ -1482,7 +1496,7 @@ class Vits(BaseTTS):
            d_vectors = torch.FloatTensor(d_vectors)

        # get language ids from language names
-        if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding:
+        if self.language_manager is not None and self.language_manager.name_to_id and (self.args.use_language_embedding or self.args.use_adaptive_weight_text_encoder):
            language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]]

        if language_ids is not None:
@@ -1547,6 +1561,21 @@ class Vits(BaseTTS):
        return batch

    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        if self.args.use_perfect_class_batch_sampler:
+            batch_size = config.eval_batch_size if is_eval else config.batch_size
+            data_items = dataset.samples
+            classes = [item[self.args.perfect_class_batch_sampler_key] for item in data_items]
+            classes = set(classes)
+            dataset.num_classes = len(classes)
+            batch_sampler = PerfectBatchSampler(
+                dataset_items=data_items,
+                classes=classes,
+                batch_size=batch_size,
+                num_classes_in_batch=len(classes),
+                label_key=self.args.perfect_class_batch_sampler_key,
+            )
+            return batch_sampler
+
        weights = None
        data_items = dataset.samples
        if getattr(config, "use_weighted_sampler", False):
@@ -1631,7 +1660,7 @@ class Vits(BaseTTS):
                    pin_memory=False,
                )
            else:
-                if num_gpus > 1:
+                if num_gpus > 1 and not self.args.use_perfect_class_batch_sampler:
                    loader = DataLoader(
                        dataset,
                        sampler=sampler,
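PerfectBatchSampler yields whole batches of indices rather than single indices. Below is a minimal, self-contained sketch of how such a batch-level sampler is typically wired into a PyTorch DataLoader via the batch_sampler argument; this is an illustration of the general pattern, not the exact code path shown in the hunk above.

from torch.utils.data import DataLoader, Dataset
from TTS.utils.samplers import PerfectBatchSampler

items = [{"language": lang, "idx": i} for i, lang in enumerate(["en", "pt", "de"] * 4)]

class ToyDataset(Dataset):
    def __len__(self):
        return len(items)

    def __getitem__(self, idx):
        return items[idx]

batch_sampler = PerfectBatchSampler(
    dataset_items=items,
    classes=["en", "pt", "de"],
    batch_size=6,                  # must be a multiple of num_classes_in_batch
    num_classes_in_batch=3,
    label_key="language",
)
# A batch sampler is passed as `batch_sampler`; batch_size, shuffle and sampler are then left unset.
loader = DataLoader(ToyDataset(), batch_sampler=batch_sampler, collate_fn=lambda b: b)
for batch in loader:
    print([s["language"] for s in batch])   # two samples from each language per batch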
@@ -92,7 +92,7 @@ class LanguageManager(BaseIDManager):
            config (Coqpit): Coqpit config.
        """
        language_manager = None
-        if check_config_and_model_args(config, "use_language_embedding", True):
+        if check_config_and_model_args(config, "use_language_embedding", True) or check_config_and_model_args(config, "use_adaptive_weight_text_encoder", True):
            if config.get("language_ids_file", None):
                language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
            language_manager = LanguageManager(config=config)
@@ -0,0 +1,230 @@
+import os
+
+import torch
+from trainer import Trainer, TrainerArgs
+
+from TTS.bin.compute_embeddings import compute_embeddings
+from TTS.bin.resample import resample_files
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig, VitsDataset
+from TTS.utils.downloaders import download_libri_tts
+from torch.utils.data import DataLoader
+from TTS.utils.samplers import PerfectBatchSampler
+torch.set_num_threads(24)
+
+# pylint: disable=W0105
+"""
+This recipe is adapted from the CML-TTS YourTTS recipe (https://arxiv.org/abs/2306.10097) and trains YourTTS with the SYNTACC adaptive-weight text encoder.
+The YourTTS model is based on the VITS model, but it uses external speaker embeddings extracted from a pre-trained speaker encoder and has small architecture changes.
+"""
+CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
+
+# Name of the run for the Trainer
+RUN_NAME = "YourTTS-CML-TTS"
+
+# Path where you want to save the model outputs (configs, checkpoints and tensorboard logs)
+OUT_PATH = os.path.dirname(os.path.abspath(__file__))  # "/raid/coqui/Checkpoints/original-YourTTS/"
+
+# If you want to do transfer learning and speed up your training, you can set here the path to the CML-TTS checkpoint that can be downloaded here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p
+RESTORE_PATH = None  # "/raid/edresson/CML_YourTTS/checkpoints_yourtts_cml_tts_dataset/best_model.pth"  # Download the checkpoint here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p
+
+# This parameter is useful for debugging: it skips the training epochs and only runs the evaluation and produces the test sentences
+SKIP_TRAIN_EPOCH = False
+
+# Set here the batch size to be used in training and evaluation
+BATCH_SIZE = 6
+
+# Training sampling rate and the target sampling rate for resampling the downloaded dataset (Note: if you change this you might need to redownload the dataset!)
+# Note: if you add new datasets, please make sure that the dataset sampling rate and this parameter match, otherwise resample your audios
+SAMPLE_RATE = 24000
+
+# Max audio length in seconds to be used in training (every audio longer than this will be ignored)
+MAX_AUDIO_LEN_IN_SECONDS = float("inf")
+
+# Define the dataset configs here
+esd_train_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="esd",
+    meta_file_train="metadata_with_basic_metrics.csv",  # TODO: compute emotion and d-vectors for test and evaluation splits
+    path="/raid/datasets/Emotion/ESD-44kHz-VAD-renormalized/",
+    language="en"
+)
+
+savee_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="savee",
+    path="/raid/datasets/SAVEE-44khz/",
+    meta_file_train="metadata_with_basic_metrics.csv",
+    language="pt"
+)
+game1_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="game1",
+    path="/raid/datasets/new_game_data/game1/datasetbuilder_formatted/",
+    meta_file_train="metadata_with_basic_metrics.csv",
+    language="de",
+)
+DATASETS_CONFIG_LIST = [esd_train_config, savee_config, game1_config]
+
+
+### Extract speaker embeddings
+SPEAKER_ENCODER_CHECKPOINT_PATH = (
+    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar"
+)
+SPEAKER_ENCODER_CONFIG_PATH = "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json"
+
+D_VECTOR_FILES = []  # List of speaker embeddings/d-vectors to be used during the training
+
+# Iterate over all the dataset configs, checking whether the speaker embeddings are already computed; if not, compute them
+for dataset_conf in DATASETS_CONFIG_LIST:
+    # Check if the embeddings were already computed; if not, compute them
+    embeddings_file = os.path.join(dataset_conf.path, "H_ASP_speaker_embeddings.pth")
+    if not os.path.isfile(embeddings_file):
+        print(f">>> Computing the speaker embeddings for the {dataset_conf.dataset_name} dataset")
+        compute_embeddings(
+            SPEAKER_ENCODER_CHECKPOINT_PATH,
+            SPEAKER_ENCODER_CONFIG_PATH,
+            embeddings_file,
+            old_speakers_file=None,
+            config_dataset_path=None,
+            formatter_name=dataset_conf.formatter,
+            dataset_name=dataset_conf.dataset_name,
+            dataset_path=dataset_conf.path,
+            meta_file_train=dataset_conf.meta_file_train,
+            meta_file_val=dataset_conf.meta_file_val,
+            disable_cuda=False,
+            no_eval=False,
+        )
+    D_VECTOR_FILES.append(embeddings_file)
+
+
+# Audio config used in training.
+audio_config = VitsAudioConfig(
+    sample_rate=SAMPLE_RATE,
+    hop_length=256,
+    win_length=1024,
+    fft_size=1024,
+    mel_fmin=0.0,
+    mel_fmax=None,
+    num_mels=80,
+)
+
+# Init VitsArgs, setting the arguments that are needed for the YourTTS model
+model_args = VitsArgs(
+    spec_segment_size=62,
+    hidden_channels=192,
+    hidden_channels_ffn_text_encoder=768,
+    num_heads_text_encoder=2,
+    num_layers_text_encoder=10,
+    kernel_size_text_encoder=3,
+    dropout_p_text_encoder=0.1,
+    d_vector_file=D_VECTOR_FILES,
+    use_d_vector_file=True,
+    d_vector_dim=512,
+    speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
+    speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
+    resblock_type_decoder="2",  # In the YourTTS paper we accidentally trained with ResNet blocks type 2; if you prefer, you can use ResNet blocks type 1 like the VITS model
+    # Useful parameters to enable the Speaker Consistency Loss (SCL) described in the paper
+    use_speaker_encoder_as_loss=False,
+    # Useful parameters to enable multilingual training
+    use_language_embedding=False,
+    embedded_language_dim=4,
+    use_adaptive_weight_text_encoder=True,
+    use_perfect_class_batch_sampler=True,
+    perfect_class_batch_sampler_key="language"
+)
+
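Because the perfect class batch sampler is keyed on "language" here, every batch mixes all training languages equally, and the per-class mini-batch split in the text encoder divides the batch size by the number of classes. A quick sanity check along those lines (an illustrative assertion, not part of the recipe itself):

num_languages = len({d.language for d in DATASETS_CONFIG_LIST})      # 3: "en", "pt", "de"
assert BATCH_SIZE % num_languages == 0, "BATCH_SIZE must be a multiple of the number of languages"
print(BATCH_SIZE // num_languages, "utterances per language in every batch")   # 2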
+# General training config; here you can change the batch size and other useful parameters
+config = VitsConfig(
+    output_path=OUT_PATH,
+    model_args=model_args,
+    run_name=RUN_NAME,
+    project_name="SYNTACC",
+    run_description="""
+        - YourTTS with SYNTACC text encoder
+    """,
+    dashboard_logger="tensorboard",
+    logger_uri=None,
+    audio=audio_config,
+    batch_size=BATCH_SIZE,
+    batch_group_size=48,
+    eval_batch_size=BATCH_SIZE,
+    num_loader_workers=8,
+    eval_split_max_size=256,
+    print_step=50,
+    plot_step=100,
+    log_model_step=1000,
+    save_step=5000,
+    save_n_checkpoints=2,
+    save_checkpoints=True,
+    # target_loss="loss_1",
+    print_eval=False,
+    use_phonemes=False,
+    phonemizer="espeak",
+    phoneme_language="en",
+    compute_input_seq_cache=True,
+    add_blank=True,
+    text_cleaner="multilingual_cleaners",
+    characters=CharactersConfig(
+        characters_class="TTS.tts.models.vits.VitsCharacters",
+        pad="_",
+        eos="&",
+        bos="*",
+        blank=None,
+        characters="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00a1\u00a3\u00b7\u00b8\u00c0\u00c1\u00c2\u00c3\u00c4\u00c5\u00c7\u00c8\u00c9\u00ca\u00cb\u00cc\u00cd\u00ce\u00cf\u00d1\u00d2\u00d3\u00d4\u00d5\u00d6\u00d9\u00da\u00db\u00dc\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e5\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u0101\u0104\u0105\u0106\u0107\u010b\u0119\u0141\u0142\u0143\u0144\u0152\u0153\u015a\u015b\u0161\u0178\u0179\u017a\u017b\u017c\u020e\u04e7\u05c2\u1b20",
+        punctuations="\u2014!'(),-.:;?\u00bf ",
+        phonemes="iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u025a\u02de\u026b'\u0303' ",
+        is_unique=True,
+        is_sorted=True,
+    ),
+    phoneme_cache_path=None,
+    precompute_num_workers=12,
+    start_by_longest=True,
+    datasets=DATASETS_CONFIG_LIST,
+    cudnn_benchmark=False,
+    max_audio_len=SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS,
+    mixed_precision=False,
+    test_sentences=[
+        ["Voc\u00ea ter\u00e1 a vista do topo da montanha que voc\u00ea escalar.", "ESD_0012", None, "pt"],
+        ["Quando voc\u00ea n\u00e3o corre nenhum risco, voc\u00ea arrisca tudo.", "ESD_0012", None, "pt"],
+    ],
+    # Enable the weighted sampler
+    use_weighted_sampler=True,
+    # Ensures that all speakers are seen in the training batch equally no matter how many samples each speaker has
+    # weighted_sampler_attrs={"language": 1.0, "speaker_name": 1.0},
+    weighted_sampler_attrs={"language": 1.0},
+    weighted_sampler_multipliers={
+        # "speaker_name": {
+        # you can force the batching scheme to give a higher weight to a certain speaker and then this speaker will appear more frequently in the batch.
+        # It will speed up the speaker adaptation process. Considering the CML train dataset and "new_speaker" as the speaker name of the speaker that you want to adapt,
+        # the line below will make the balancer consider "new_speaker" as 106 speakers, i.e. 1/4 of the number of speakers present in the CML dataset.
+        # 'new_speaker': 106, # (CML tot. train speakers)/4 = (424/4) = 106
+        # }
+    },
+    # Set the Speaker Consistency Loss (SCL) α to 9 as in the YourTTS paper
+    speaker_encoder_loss_alpha=9.0,
+)
+
+# Load all the dataset samples and split training and evaluation sets
+train_samples, eval_samples = load_tts_samples(
+    config.datasets,
+    eval_split=True,
+    eval_split_max_size=config.eval_split_max_size,
+    eval_split_size=config.eval_split_size,
+)
+
+# Init the model
+model = Vits.init_from_config(config)
+
+# Init the trainer and 🚀
+trainer = Trainer(
+    TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=True),
+    config,
+    output_path=OUT_PATH,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+)
+trainer.fit()
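After training, a quick way to listen to the result is to load the saved checkpoint and synthesize one of the test sentences, passing the language name so the adaptive-weight text encoder selects the matching per-language weights. This is a hypothetical snippet: the run folder, checkpoint and config paths below are placeholders, since the Trainer writes them into a timestamped run directory under OUT_PATH.

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint=os.path.join(OUT_PATH, "<run_folder>", "best_model.pth"),  # placeholder path
    tts_config_path=os.path.join(OUT_PATH, "<run_folder>", "config.json"),    # placeholder path
    use_cuda=True,
)
wav = synthesizer.tts(
    "Quando você não corre nenhum risco, você arrisca tudo.",
    speaker_name="ESD_0012",
    language_name="pt",
)
synthesizer.save_wav(wav, "syntacc_sample.wav")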