Add XTTS base training code

This commit is contained in:
Edresson Casanova 2023-10-11 16:09:51 -03:00
parent 1e152692ed
commit a32961bcb4
3 changed files with 1014 additions and 0 deletions

View File

@ -0,0 +1,202 @@
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load
torch.set_num_threads(1)
def key_samples_by_col(samples, col):
"""Returns a dictionary of samples keyed by language."""
samples_by_col = {}
for sample in samples:
col_val = sample[col]
assert isinstance(col_val, str)
if col_val not in samples_by_col:
samples_by_col[col_val] = []
samples_by_col[col_val].append(sample)
return samples_by_col
def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate):
rel_clip = load_audio(gt_path, sample_rate)
sample_length = random.randint(min_sample_length, max_sample_length)
gap = rel_clip.shape[-1] - sample_length
if gap < 0:
sample_length = rel_clip.shape[-1] // 2
gap = rel_clip.shape[-1] - sample_length
rand_start = random.randint(0, gap)
rand_end = rand_start+sample_length
rel_clip = rel_clip[:, rand_start:rand_end]
rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
cond_idxs = [rand_start, rand_end]
return rel_clip, rel_clip.shape[-1], cond_idxs
def load_audio(audiopath, sampling_rate):
# better load setting following: https://github.com/faroit/python_audio_loading_benchmark
if audiopath[-4:] == '.mp3':
# use the torchaudio sox backend to load mp3 files
audio, lsr = torchaudio_sox_load(audiopath)
else:
# use the torchaudio soundfile backend to load all other file types
audio, lsr = torchaudio_soundfile_load(audiopath)
# stereo to mono if needed
if audio.size(0) != 1:
audio = torch.mean(audio, dim=0, keepdim=True)
if lsr != sampling_rate:
audio = torchaudio.functional.resample(audio, lsr, sampling_rate)
# Check some assumptions about audio range. This should be automatically fixed during loading, but might not be in some edge cases, where we should squawk.
# '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
if torch.any(audio > 10) or not torch.any(audio < 0):
print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
# clip invalid audio values
audio.clip_(-1, 1)
return audio
class XTTSDataset(torch.utils.data.Dataset):
def __init__(self, config, samples, tokenizer, sample_rate):
self.config = config
model_args = config.model_args
self.failed_samples = set()
self.debug_failures = model_args.debug_loading_failures
self.max_conditioning_length = model_args.max_conditioning_length
self.min_conditioning_length = model_args.min_conditioning_length
# self.samples = []
# cache the samples and add type "0" to all samples
# ToDo: find a better way to deal with type
# for item in samples:
# self.samples.append([item['audio_file'], item["text"], 0])
self.samples = samples
random.seed(config.training_seed)
# random.shuffle(self.samples)
random.shuffle(self.samples)
# group samples by language
self.samples = key_samples_by_col(self.samples, "language")
print(" > Sampling by language:", self.samples.keys())
# always use the output sampling rate to load audio at the highest quality
self.sample_rate = sample_rate
self.max_wav_len = model_args.max_wav_length
self.max_text_len = model_args.max_text_length
assert self.max_wav_len is not None and self.max_text_len is not None
# load specific vocabulary
self.tokenizer = tokenizer
def get_text(self, text, lang):
tokens = self.tokenizer.encode(text, lang)
tokens = torch.IntTensor(tokens)
assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}"
# The stop token should always be sacred.
assert not torch.any(tokens == 0), f"Stop token found in {text}"
return tokens
def load_item(self, sample):
text = str(sample['text'])
tseq = self.get_text(text, sample["language"])
audiopath = sample['audio_file']
wav = load_audio(audiopath, self.sample_rate)
if text is None or len(text.strip()) == 0:
raise ValueError
if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
# Ultra short clips are also useless (and can cause problems within some models).
raise ValueError
# get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate)
return tseq, audiopath, wav, cond, cond_len, cond_idxs
def __getitem__(self, index):
# select a random language
lang = random.choice(list(self.samples.keys()))
# select random sample
index = random.randint(0, len(self.samples[lang]) - 1)
sample = self.samples[lang][index]
# a unique id for each sample to track loading failures
sample_id = lang+"_"+str(index)
# ignore samples that we already know are not valid
if sample_id in self.failed_samples:
if self.debug_failures:
print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
# call __getitem__ again to get another sample
return self[1]
# try to load the sample; if it fails, add it to the failed samples list
try:
tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
except:
if self.debug_failures:
print(f"error loading {sample['audio_file']} {sys.exc_info()}")
self.failed_samples.add(sample_id)
return self[1]
# check the audio and text size limits; if the sample is out of bounds, add it to failed_samples
if wav is None or \
(self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
(self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
# Basically, this audio file is nonexistent or too long to be supported by the dataset.
# It's hard to handle this situation properly. Best bet is to return a random valid sample and skew the dataset somewhat as a result.
if self.debug_failures and wav is not None and tseq is not None:
print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
self.failed_samples.add(sample_id)
return self[1]
res = {
# 'real_text': text,
'text': tseq,
'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
'wav': wav,
'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
'filenames': audiopath,
'conditioning': cond.unsqueeze(1),
'cond_lens': torch.tensor(cond_len, dtype=torch.long),
'cond_idxs': torch.tensor(cond_idxs),
}
return res
def __len__(self):
return sum([len(v) for v in self.samples.values()])
def collate_fn(self, batch):
# convert list of dicts to dict of lists
B = len(batch)
batch = {k: [dic[k] for dic in batch] for k in batch[0]}
# stack features that already have the same shape
batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
batch["text_lengths"] = torch.stack(batch["text_lengths"])
batch["conditioning"] = torch.stack(batch["conditioning"])
batch["cond_lens"] = torch.stack(batch["cond_lens"])
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
max_text_len = batch["text_lengths"].max()
max_wav_len = batch["wav_lengths"].max()
# create padding tensors
text_padded = torch.IntTensor(B, max_text_len)
wav_padded = torch.FloatTensor(B, 1, max_wav_len)
# initialize tensors for zero padding
text_padded = text_padded.zero_()
wav_padded = wav_padded.zero_()
for i in range(B):
text = batch["text"][i]
text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
wav = batch['wav'][i]
wav_padded[i, :, :batch["wav_lengths"][i]] = torch.FloatTensor(wav)
batch["wav"] = wav_padded
batch["padded_text"] = text_padded
return batch
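
A minimal usage sketch of the dataset above, assuming a populated GPTConfig, a list of metadata dicts with at least "text", "audio_file" and "language" keys, and a VoiceBpeTokenizer; the build_loader helper is hypothetical and simply mirrors the wiring that GPTTrainer.get_data_loader performs in the next file.

from torch.utils.data import DataLoader

def build_loader(config, samples, tokenizer, sample_rate=22050, batch_size=2):
    # illustrative only: wrap XTTSDataset with its custom collate_fn that zero-pads text and wav
    dataset = XTTSDataset(config, samples, tokenizer, sample_rate)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # the dataset already draws a random language/sample per __getitem__
        drop_last=False,
        collate_fn=dataset.collate_fn,
        num_workers=0,
    )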

View File

@ -0,0 +1,448 @@
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torchaudio
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import sys
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
from TTS.tts.models.base_tts import BaseTTS
from coqpit import Coqpit
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
from TTS.tts.datasets.dataset import TTSDataset
from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.utils.io import load_fsspec
from TTS.tts.layers.xtts.dvae import DiscreteVAE
@dataclass
class GPTConfig(TortoiseConfig):
lr: float = 5e-06
training_seed: int = 1
optimizer_wd_only_on_weights: bool = False
use_weighted_loss: bool = False # TODO: move it to the base config
weighted_loss_attrs: dict = field(default_factory=lambda: {})
weighted_loss_multipliers: dict = field(default_factory=lambda: {})
@dataclass
class XttsAudioConfig(XttsAudioConfig):
dvae_sample_rate: int = 22050
@dataclass
class GPTArgs(XttsArgs):
min_conditioning_length: int = 66150
max_conditioning_length: int = 132300
gpt_loss_text_ce_weight: float = 0.01
gpt_loss_mel_ce_weight: float = 1.0
gpt_num_audio_tokens: int = 8194
debug_loading_failures: bool = False
max_wav_length: int = 255995 # ~11.6 seconds
max_text_length: int = 200
tokenizer_file: str = ""
mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
dvae_checkpoint: str = ""
gpt_checkpoint: str = ""
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
def callback_clearml_load_save(operation_type, model_info):
# returning None means skip the file upload/log; returning model_info will continue with the log/upload
# you can also change the upload destination file name (model_info.upload_filename) or check the local file size with Path(model_info.local_model_path).stat().st_size
assert operation_type in ('load', 'save')
# print(operation_type, model_info.__dict__)
if "similarities.pth" in model_info.__dict__['local_model_path']:
return None
return model_info
class GPTTrainer(BaseTTS):
def __init__(self, config: Coqpit):
"""
XTTS GPT training class
"""
super().__init__(config, ap=None, tokenizer=None)
self.config = config
self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
self.gpt = GPT(
layers=self.args.gpt_layers,
model_dim=self.args.gpt_n_model_channels,
start_text_token=self.args.gpt_start_text_token,
stop_text_token=self.args.gpt_stop_text_token,
heads=self.args.gpt_n_heads,
max_text_tokens=self.args.gpt_max_text_tokens,
max_mel_tokens=self.args.gpt_max_audio_tokens,
max_prompt_tokens=self.args.gpt_max_prompt_tokens,
number_text_tokens=self.args.gpt_number_text_tokens,
num_audio_tokens=self.args.gpt_num_audio_tokens,
start_audio_token=self.args.gpt_start_audio_token,
stop_audio_token=self.args.gpt_stop_audio_token,
).cuda()
# load GPT if available
if self.args.gpt_checkpoint:
gpt_checkpoint = torch.load(
self.args.gpt_checkpoint, map_location=torch.device("cpu")
)
# deal with coqui Trainer exported model
if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
print("Coqui Trainer checkpoint detected! Converting it!")
gpt_checkpoint = gpt_checkpoint["model"]
states_keys = list(gpt_checkpoint.keys())
for key in states_keys:
if "gpt." in key:
new_key = key.replace("gpt.", "")
gpt_checkpoint[new_key] = gpt_checkpoint[key]
del gpt_checkpoint[key]
else:
del gpt_checkpoint[key]
# edit the checkpoint if the number of tokens has changed, to ensure the best possible transfer learning
if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
# add new rows to the text embedding layer (text_embedding)
emb_g = gpt_checkpoint["text_embedding.weight"]
new_row = torch.randn(num_new_tokens, emb_g.shape[1])
start_token_row = emb_g[-1, :]
emb_g = torch.cat([emb_g, new_row], axis=0)
emb_g[-1, :] = start_token_row
gpt_checkpoint["text_embedding.weight"] = emb_g
# add new weights to the linear layer (text_head)
text_head_weight = gpt_checkpoint["text_head.weight"]
start_token_row = text_head_weight[-1, :]
new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
text_head_weight[-1, :] = start_token_row
gpt_checkpoint["text_head.weight"] = text_head_weight
# add new biases to the linear layer (text_head)
text_head_bias = gpt_checkpoint["text_head.bias"]
start_token_row = text_head_bias[-1]
new_bias_entry = torch.zeros(num_new_tokens)
text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
text_head_bias[-1] = start_token_row
gpt_checkpoint["text_head.bias"] = text_head_bias
self.gpt.load_state_dict(gpt_checkpoint, strict=True)
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
else:
print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")
# Mel spectrogram extractor for conditioning
self.torch_mel_spectrogram = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.sample_rate)
# Load DVAE
self.dvae = DiscreteVAE(
channels=80,
normalization=None,
positional_dims=1,
num_tokens=self.args.gpt_num_audio_tokens - 2,
codebook_dim=512,
hidden_dim=512,
num_resnet_blocks=3,
kernel_size=3,
num_layers=2,
use_transposed_convs=False,
)
self.dvae.eval()
if self.args.dvae_checkpoint:
dvae_checkpoint = torch.load(
self.args.dvae_checkpoint, map_location=torch.device("cpu")
)
self.dvae.load_state_dict(dvae_checkpoint, strict=False)
print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
else:
raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")
# Mel spectrogram extractor for DVAE
self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
@property
def device(self):
return next(self.parameters()).device
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
text_inputs: long tensor, (b,t)
text_lengths: long tensor, (b,)
mel_inputs: long tensor, (b,m)
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
cond_lengths: long tensor, (b,)
"""
losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
return losses
@torch.no_grad()
def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613
return {}, {}
def format_batch(self, batch: Dict) -> Dict:
return batch
@torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction
def format_batch_on_device(self, batch):
"""Compute spectrograms on the device."""
batch["text_lengths"] = batch["text_lengths"]
batch["wav_lengths"] = batch["wav_lengths"]
batch["text_inputs"] = batch["padded_text"]
batch["cond_lens"] = batch["cond_lens"]
batch["cond_idxs"] = batch["cond_idxs"]
# compute conditioning mel specs
# reshape waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
B, num_cond_samples, C, T = batch["conditioning"].size()
conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T)
paired_conditioning_mel = self.torch_mel_spectrogram(conditioning_reshaped)
# reshape torch.Size([B * num_cond_samples, n_mel, T_mel]) into torch.Size([B, num_cond_samples, n_mel, T_mel])
n_mel = self.torch_mel_spectrogram.n_mel_channels # paired_conditioning_mel.size(1)
T_mel = paired_conditioning_mel.size(2)
paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
# get the conditioning embeddings
batch["cond_mels"] = paired_conditioning_mel
# compute codes using DVAE
if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
dvae_wav = torchaudio.functional.resample(
batch["wav"],
orig_freq=self.config.audio.sample_rate,
new_freq=self.config.audio.dvae_sample_rate,
lowpass_filter_width=64,
rolloff=0.9475937167399596,
resampling_method="kaiser_window",
beta=14.769656459379492,
)
else:
dvae_wav = batch["wav"]
dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
codes = self.dvae.get_codebook_indices(dvae_mel_spec)
batch["audio_codes"] = codes
# delete useless batch tensors
del batch["padded_text"]
del batch["wav"]
del batch["conditioning"]
return batch
def train_step(self, batch, criterion):
loss_dict = {}
cond_mels = batch["cond_mels"]
text_inputs = batch["text_inputs"]
text_lengths = batch["text_lengths"]
audio_codes = batch["audio_codes"]
wav_lengths = batch["wav_lengths"]
cond_lens = batch["cond_lens"]
# Todo: implement masking on the cond slice
cond_idxs = batch["cond_idxs"]
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
return {"model_outputs": None}, loss_dict
def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)
def on_epoch_start(self, trainer): # pylint: disable=W0613
# guarantee that the DVAE stays in eval mode after .train() is called at the end of evaluation
self.dvae = self.dvae.eval()
def on_init_end(self, trainer): # pylint: disable=W0613
# ignore similarities.pth on clearml save/upload
if self.config.dashboard_logger.lower() == "clearml":
from clearml.binding.frameworks import WeightsFileHandler
WeightsFileHandler.add_pre_callback(callback_clearml_load_save)
@torch.no_grad()
def inference(
self,
x,
aux_input=None,
): # pylint: disable=dangerous-default-value
return None
@staticmethod
def get_criterion():
return None
def get_sampler(self, dataset: TTSDataset, num_gpus=1):
# sampler for DDP
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
return batch_sampler
def get_data_loader(
self,
config: Coqpit,
assets: Dict,
is_eval: bool,
samples: Union[List[Dict], List[List]],
verbose: bool,
num_gpus: int,
rank: int = None,
) -> "DataLoader": # pylint: disable=W0613
if is_eval and not config.run_eval:
loader = None
else:
# ToDo: remove the randomness of the dataset when it is eval
# init dataloader
dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate)
# wait for all the DDP processes to be ready
if num_gpus > 1:
torch.distributed.barrier()
# sort input sequences from short to long
# dataset.preprocess_samples()
# get samplers
sampler = self.get_sampler(dataset, num_gpus)
# ignore the sampler during eval, because changing the sampler parameters would make results incomparable with previous runs
if sampler is None or is_eval:
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False,
drop_last=False,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
else:
loader = DataLoader(
dataset,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
)
return loader
def get_optimizer(self) -> List:
"""Initiate and return the optimizer based on the config parameters.
"""
# ToDo: deal with multi GPU training
if self.config.optimizer_wd_only_on_weights:
# parameters of the GPT model only
net = self.gpt
# normalizations
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
# nn.Embedding
emb_modules = (nn.Embedding, nn.EmbeddingBag)
param_names_notweights = set()
all_param_names = set()
param_map = {}
for mn, m in net.named_modules():
for k, v in m.named_parameters():
v.is_bias = k.endswith(".bias")
v.is_weight = k.endswith(".weight")
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)
fpn = '%s.%s' % (mn, k) if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
param_names_notweights.add(fpn)
params_names_notweights = sorted(list(param_names_notweights))
params_notweights = [param_map[k] for k in params_names_notweights]
params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
params_weights = [param_map[k] for k in params_names_weights]
groups = [
{ 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
{ 'params': params_notweights, 'weight_decay': 0}
]
# torch.optim.AdamW
opt = get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
parameters=groups,
)
opt._group_names = [params_names_weights, params_names_notweights]
return opt
return get_optimizer(
self.config.optimizer,
self.config.optimizer_params,
self.config.lr,
# optimize only for the GPT model
parameters=self.gpt.parameters(),
)
def get_scheduler(self, optimizer) -> List:
"""Set the scheduler for the optimizer.
Args:
optimizer: `torch.optim.Optimizer`.
"""
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)
def load_checkpoint(
self,
config,
checkpoint_path,
eval=False,
strict=True,
cache_storage="/tmp/tts_cache",
target_protocol="s3",
target_options={"anon": True},
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# load the model weights
self.gpt.load_state_dict(state, strict=strict)
if eval:
self.eval()
self.set_inference()
assert not self.training
@staticmethod
def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
"""Initiate model from config
Args:
config (GPTConfig): Model config.
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
Defaults to None.
"""
return GPTTrainer(config)
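
The checkpoint surgery above, which grows text_embedding and text_head when the tokenizer vocabulary changes, can be read in isolation. The sketch below uses a hypothetical grow_text_embeddings helper and assumes, as the code above does, that new rows are randomly initialized (biases zero-initialized) while the original last row, the special token, stays in the last position.

import torch

def grow_text_embeddings(state_dict, new_vocab_size):
    # illustrative sketch of the resizing done in GPTTrainer.__init__
    emb = state_dict["text_embedding.weight"]  # (old_vocab, dim)
    num_new = new_vocab_size - emb.shape[0]
    if num_new <= 0:
        return state_dict
    for key, filler in (
        ("text_embedding.weight", torch.randn(num_new, emb.shape[1])),
        ("text_head.weight", torch.randn(num_new, state_dict["text_head.weight"].shape[1])),
        ("text_head.bias", torch.zeros(num_new)),
    ):
        old = state_dict[key]
        last = old[-1].clone()                 # keep the special token row/entry
        new = torch.cat([old, filler], dim=0)  # append randomly initialized rows
        new[-1] = last                         # move the special token back to the last position
        state_dict[key] = new
    return state_dict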

View File

@ -0,0 +1,364 @@
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_german",
meta_file_train="metadata_train_with_previous_audio_key.csv",
language="de",
)
config_coqui_MLS_metadata_test_with_previous_audio_key_de = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_german",
meta_file_train="metadata_test_with_previous_audio_key.csv",
language="de",
)
config_coqui_MLS_metadata_dev_with_previous_audio_key_de = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_german",
meta_file_train="metadata_dev_with_previous_audio_key.csv",
language="de",
)
config_coqui_mls_french_metadata_with_previous_audio_key_fr = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_french/",
meta_file_train="metadata_with_previous_audio_key.csv",
language="fr",
)
config_coqui_mls_spanish_metadata_with_previous_audio_key_es = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_spanish/",
meta_file_train="/raid/datasets/MLS/mls_spanish/metadata_with_previous_audio_key.csv",
language="es",
)
config_coqui_mls_italian_metadata_with_previous_audio_key_it = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_italian/",
meta_file_train="/raid/datasets/MLS/mls_italian/metadata_with_previous_audio_key.csv",
language="it",
)
config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_portuguese/",
meta_file_train="/raid/datasets/MLS/mls_portuguese/metadata_with_previous_audio_key.csv",
language="pt",
)
config_coqui_mls_polish_metadata_with_previous_audio_key_pl = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/MLS/mls_polish/",
meta_file_train="/raid/datasets/MLS/mls_polish/metadata_with_previous_audio_key.csv",
language="pl",
)
config_coqui_common_voice_metafile_it_train_with_scores_it = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_it_train_with_scores.csv",
language="it",
)
config_coqui_common_voice_metafile_it_test_with_scores_it = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_it_test_with_scores.csv",
language="it",
)
config_coqui_common_voice_metafile_it_dev_with_scores_it = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_it_dev_with_scores.csv",
language="it",
)
config_coqui_common_voice_metafile_pt_train_with_scores_pt = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_pt_train_with_scores.csv",
language="pt",
)
config_coqui_common_voice_metafile_pt_test_with_scores_pt = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_pt_test_with_scores.csv",
language="pt",
)
config_coqui_common_voice_metafile_pt_dev_with_scores_pt = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_pt_dev_with_scores.csv",
language="pt",
)
config_coqui_common_voice_metafile_en_train_en = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_en_train.csv",
language="en",
)
config_coqui_common_voice_metafile_en_test_en = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_en_test.csv",
language="en",
)
config_coqui_common_voice_metafile_en_dev_en = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_en_dev.csv",
language="en",
)
config_coqui_common_voice_metafile_tr_validated_tr = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_tr_validated.csv",
language="tr",
)
config_coqui_common_voice_metafile_ru_validated_ru = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_ru_validated.csv",
language="ru",
)
config_coqui_common_voice_metafile_nl_validated_nl = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_nl_validated.csv",
language="nl",
)
config_coqui_common_voice_metafile_cs_validated_cs = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_cs_validated.csv",
language="cs",
)
config_coqui_common_voice_metafile_fr_validated_fr = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_fr_validated.csv",
language="fr",
)
config_coqui_common_voice_metafile_es_validated_es = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_es_validated.csv",
language="es",
)
config_coqui_common_voice_metafile_pl_validated_pl = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_pl_validated.csv",
language="pl",
)
config_coqui_common_voice_metafile_ar_validated_ar = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_ar_validated.csv",
language="ar",
)
config_coqui_common_voice_metafile_zh_CN_validated_zh_cn = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_zh-CN_validated.csv",
language="zh-cn",
)
config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
formatter="coqui",
dataset_name="coqui",
path="/raid/datasets/common_voice/",
meta_file_train="/raid/datasets/common_voice/metafile_ja_validated.csv",
language="ja",
)
# DATASETS_CONFIG_LIST=[config_coqui_MLS_metadata_train_with_previous_audio_key_de, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_MLS_metadata_dev_with_previous_audio_key_de, config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it, config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt, config_coqui_mls_polish_metadata_with_previous_audio_key_pl, config_coqui_common_voice_metafile_it_train_with_scores_it, config_coqui_common_voice_metafile_it_test_with_scores_it, config_coqui_common_voice_metafile_it_dev_with_scores_it, config_coqui_common_voice_metafile_pt_train_with_scores_pt, config_coqui_common_voice_metafile_pt_test_with_scores_pt, config_coqui_common_voice_metafile_pt_dev_with_scores_pt, config_coqui_common_voice_metafile_en_train_en, config_coqui_common_voice_metafile_en_test_en, config_coqui_common_voice_metafile_en_dev_en, config_coqui_common_voice_metafile_tr_validated_tr, config_coqui_common_voice_metafile_ru_validated_ru, config_coqui_common_voice_metafile_nl_validated_nl, config_coqui_common_voice_metafile_cs_validated_cs, config_coqui_common_voice_metafile_fr_validated_fr, config_coqui_common_voice_metafile_es_validated_es, config_coqui_common_voice_metafile_pl_validated_pl, config_coqui_common_voice_metafile_ar_validated_ar, config_coqui_common_voice_metafile_zh_CN_validated_zh_cn, config_coqui_common_voice_metafile_ja_validated_ja]
# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
def freeze_layers(trainer):
pass
def main():
# init args and config
model_args = GPTArgs(
max_conditioning_length=132300, # 6 secs
min_conditioning_length=66150, # 3 secs
debug_loading_failures=True,
max_wav_length=255995, # ~11.6 seconds
max_text_length=200,
tokenizer_file="/raid/datasets/xtts_models/vocab.json",
mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
gpt_num_audio_tokens=8194,
gpt_start_audio_token=8192,
gpt_stop_audio_token=8193,
)
audio_config = XttsAudioConfig(
sample_rate=22050, # autoregressive SR
dvae_sample_rate=22050,
diffusion_sample_rate=24000,
output_sample_rate=24000
)
config = GPTConfig(
output_path=OUT_PATH,
model_args=model_args,
run_name=RUN_NAME,
project_name=PROJECT_NAME,
run_description="""
GPT XTTS training
""",
dashboard_logger=DASHBOARD_LOGGER,
logger_uri=LOGGER_URI,
audio=audio_config,
batch_size=BATCH_SIZE,
batch_group_size=48,
eval_batch_size=BATCH_SIZE,
num_loader_workers=8,
eval_split_max_size=256,
print_step=50,
plot_step=100,
log_model_step=1000,
save_step=10000,
save_n_checkpoints=1,
save_checkpoints=True,
# target_loss="loss",
print_eval=False,
# Optimizer values like Tortoise. However, Tortoise used a PyTorch implementation modified to not apply weight decay to non-weight parameters; here we use the default PyTorch optimizer.
optimizer="AdamW",
optimizer_wd_only_on_weights=True,
optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2},
lr=5e-06, # learning rate
# lr=1e-4, # learning rate
# ToDo: implement a 500-step warmup like Tortoise, and EMA weights (rate 0.999) to replace LR decay
lr_scheduler="MultiStepLR",
# adjusted accordingly for the new step scheme
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
)
# init the model from config
model = GPTTrainer.init_from_config(config)
# load training samples
train_samples, eval_samples = load_tts_samples(
DATASETS_CONFIG_LIST,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS),
config,
output_path=OUT_PATH,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
callbacks={"on_epoch_start": freeze_layers}
)
trainer.fit()
if __name__ == "__main__":
RUN_NAME = "GPT_XTTS"
PROJECT_NAME = "XTTS"
OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
DASHBOARD_LOGGER = "clearml"
LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
RESTORE_PATH = None
SKIP_TRAIN_EPOCH = False
START_WITH_EVAL = True
BATCH_SIZE = 9
GRAD_ACUMM_STEPS = 28
# debug
DASHBOARD_LOGGER = "tensorboard"
LOGGER_URI = None
RESTORE_PATH = None
BATCH_SIZE = 2
GRAD_ACUMM_STEPS = 1
NUM_LOADERS = 1
main()
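
A quick sanity check on the hyperparameters above: the effective number of samples per optimizer step is the batch size times the gradient accumulation steps (per GPU). A tiny, purely illustrative sketch:

def effective_batch_size(batch_size, grad_accum_steps):
    # samples consumed per optimizer step on a single GPU
    return batch_size * grad_accum_steps

print(effective_batch_size(9, 28))  # 252 with the production values above
print(effective_batch_size(2, 1))   # 2 with the debug overrides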