From a32961bcb441c955ca9d7df879becb7c84c2ef52 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Wed, 11 Oct 2023 16:09:51 -0300 Subject: [PATCH] Add XTTS base training code --- TTS/tts/layers/xtts/trainer/dataset.py | 202 ++++++++++ TTS/tts/layers/xtts/trainer/gpt_trainer.py | 448 +++++++++++++++++++++ recipes/multilingual/xtts_v1/train_xtts.py | 364 +++++++++++++++++ 3 files changed, 1014 insertions(+) create mode 100644 TTS/tts/layers/xtts/trainer/dataset.py create mode 100644 TTS/tts/layers/xtts/trainer/gpt_trainer.py create mode 100644 recipes/multilingual/xtts_v1/train_xtts.py diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py new file mode 100644 index 00000000..9736ae6c --- /dev/null +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -0,0 +1,202 @@ +import os +import random +import sys +import numpy as np + +import torch +import torch.nn.functional as F +import torch.utils.data +import torchaudio +from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load +from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load +torch.set_num_threads(1) + +def key_samples_by_col(samples, col): + """Returns a dictionary of samples keyed by language.""" + samples_by_col = {} + for sample in samples: + col_val = sample[col] + assert isinstance(col_val, str) + if col_val not in samples_by_col: + samples_by_col[col_val] = [] + samples_by_col[col_val].append(sample) + return samples_by_col + + +def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate): + rel_clip = load_audio(gt_path, sample_rate) + sample_length = random.randint(min_sample_length, max_sample_length) + gap = rel_clip.shape[-1] - sample_length + if gap < 0: + sample_length = rel_clip.shape[-1] // 2 + gap = rel_clip.shape[-1] - sample_length + rand_start = random.randint(0, gap) + rand_end = rand_start+sample_length + rel_clip = rel_clip[:, rand_start:rand_end] + rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1])) + cond_idxs = [rand_start, rand_end] + return rel_clip, rel_clip.shape[-1], cond_idxs + + +def load_audio(audiopath, sampling_rate): + # better load setting following: https://github.com/faroit/python_audio_loading_benchmark + if audiopath[-4:] == '.mp3': + # it uses torchaudio with sox backend to load mp3 + audio, lsr = torchaudio_sox_load(audiopath) + else: + # it uses torchaudio soundfile backend to load all the others data type + audio, lsr = torchaudio_soundfile_load(audiopath) + + # stereo to mono if needed + if audio.size(0) != 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 10) or not torch.any(audio < 0): + print(f"Error with {audiopath}. 
Max={audio.max()} min={audio.min()}") + # clip audio invalid values + audio.clip_(-1, 1) + return audio + +class XTTSDataset(torch.utils.data.Dataset): + def __init__(self, config, samples, tokenizer, sample_rate): + self.config = config + model_args = config.model_args + self.failed_samples = set() + self.debug_failures = model_args.debug_loading_failures + self.max_conditioning_length = model_args.max_conditioning_length + self.min_conditioning_length = model_args.min_conditioning_length + + # self.samples = [] + # cache the samples and added type "0" for all samples + # ToDo: find a better way to deal with type + # for item in samples: + # self.samples.append([item['audio_file'], item["text"], 0]) + self.samples = samples + random.seed(config.training_seed) + # random.shuffle(self.samples) + random.shuffle(self.samples) + # order by language + self.samples = key_samples_by_col(self.samples, "language") + print(" > Sampling by language:", self.samples.keys()) + + # use always the output sampling rate to load in the highest quality + self.sample_rate = sample_rate + self.max_wav_len = model_args.max_wav_length + self.max_text_len = model_args.max_text_length + assert self.max_wav_len is not None and self.max_text_len is not None + + # load specific vocabulary + self.tokenizer = tokenizer + + def get_text(self, text, lang): + tokens = self.tokenizer.encode(text, lang) + tokens = torch.IntTensor(tokens) + assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}" + # The stop token should always be sacred. + assert not torch.any(tokens == 0), f"Stop token found in {text}" + return tokens + + def load_item(self, sample): + text = str(sample['text']) + tseq = self.get_text(text, sample["language"]) + audiopath = sample['audio_file'] + wav = load_audio(audiopath, self.sample_rate) + if text is None or len(text.strip()) == 0: + raise ValueError + if wav is None or wav.shape[-1] < (0.5 * self.sample_rate): + # Ultra short clips are also useless (and can cause problems within some models). 
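+ # raising here is deliberate: __getitem__ catches the exception, adds the clip to failed_samples and draws a replacement sample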
+ raise ValueError + + # get a slice from GT to condition the model + cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate) + + return tseq, audiopath, wav, cond, cond_len, cond_idxs + + def __getitem__(self, index): + # select a random language + lang = random.choice(list(self.samples.keys())) + # select random sample + index = random.randint(0, len(self.samples[lang]) - 1) + sample = self.samples[lang][index] + # a unique id for each sampel to deal with fails + sample_id = lang+"_"+str(index) + + # ignore samples that we already know that is not valid ones + if sample_id in self.failed_samples: + if self.debug_failures: + print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!") + # call get item again to get other sample + return self[1] + + # try to load the sample, if fails added it to the failed samples list + try: + tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) + except: + if self.debug_failures: + print(f"error loading {sample['audio_file']} {sys.exc_info()}") + self.failed_samples.add(sample_id) + return self[1] + + # check if the audio and text size limits and if it out of the limits, added it failed_samples + if wav is None or \ + (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \ + (self.max_text_len is not None and tseq.shape[0] > self.max_text_len): + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. + if self.debug_failures and wav is not None and tseq is not None: + print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}") + self.failed_samples.add(sample_id) + return self[1] + + res = { + # 'real_text': text, + 'text': tseq, + 'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long), + 'wav': wav, + 'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long), + 'filenames': audiopath, + 'conditioning': cond.unsqueeze(1), + 'cond_lens': torch.tensor(cond_len, dtype=torch.long), + 'cond_idxs': torch.tensor(cond_idxs), + } + return res + + def __len__(self): + return sum([len(v) for v in self.samples.values()]) + + def collate_fn(self, batch): + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # stack for features that already have the same shape + batch["wav_lengths"] = torch.stack(batch["wav_lengths"]) + batch["text_lengths"] = torch.stack(batch["text_lengths"]) + batch["conditioning"] = torch.stack(batch["conditioning"]) + batch["cond_lens"] = torch.stack(batch["cond_lens"]) + batch["cond_idxs"] = torch.stack(batch["cond_idxs"]) + max_text_len = batch["text_lengths"].max() + max_wav_len = batch["wav_lengths"].max() + + # create padding tensors + text_padded = torch.IntTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, max_wav_len) + + # initialize tensors for zero padding + text_padded = text_padded.zero_() + wav_padded = wav_padded.zero_() + for i in range(B): + text = batch["text"][i] + text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text) + wav = batch['wav'][i] + wav_padded[i, :, :batch["wav_lengths"][i]] = torch.FloatTensor(wav) + + batch["wav"] = wav_padded + batch["padded_text"] = text_padded + + return batch diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py 
b/TTS/tts/layers/xtts/trainer/gpt_trainer.py new file mode 100644 index 00000000..f73aeb05 --- /dev/null +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -0,0 +1,448 @@ +import os +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torchaudio +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.data import DataLoader +import sys + + +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + +from TTS.tts.models.base_tts import BaseTTS +from coqpit import Coqpit + +from TTS.tts.configs.tortoise_config import TortoiseConfig +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram + +from TTS.tts.datasets.dataset import TTSDataset + +from trainer.torch import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + + +from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset +from TTS.utils.io import load_fsspec + +from TTS.tts.layers.xtts.dvae import DiscreteVAE + +@dataclass +class GPTConfig(TortoiseConfig): + lr: float = 5e-06 + training_seed: int = 1 + optimizer_wd_only_on_weights: bool = False + use_weighted_loss: bool = False # TODO: move it to the base config + weighted_loss_attrs: dict = field(default_factory=lambda: {}) + weighted_loss_multipliers: dict = field(default_factory=lambda: {}) + + +@dataclass +class XttsAudioConfig(XttsAudioConfig): + dvae_sample_rate: int = 22050 + + +@dataclass +class GPTArgs(XttsArgs): + min_conditioning_length: int = 66150 + max_conditioning_length: int = 132300 + gpt_loss_text_ce_weight: float = 0.01 + gpt_loss_mel_ce_weight: float = 1.0 + gpt_num_audio_tokens: int = 8194 + debug_loading_failures: bool = False + max_wav_length: int = 255995 # ~11.6 seconds + max_text_length: int = 200 + tokenizer_file: str = "" + mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth" + dvae_checkpoint: str = "" + gpt_checkpoint: str = "" + vocoder: str = "" # overide vocoder key on the config to avoid json write issues + + +def callback_clearml_load_save(operation_type, model_info): + # return None means skip the file upload/log, returning model_info will continue with the log/upload + # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size + assert operation_type in ('load', 'save') + # print(operation_type, model_info.__dict__) + + if "similarities.pth" in model_info.__dict__['local_model_path']: + return None + + return model_info + +class GPTTrainer(BaseTTS): + def __init__(self, config: Coqpit): + """ + Tortoise GPT training class + """ + super().__init__(config, ap=None, tokenizer=None) + self.config = config + + self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file) + + self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size() + self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") + self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]") + + self.gpt = GPT( + layers=self.args.gpt_layers, + model_dim=self.args.gpt_n_model_channels, + start_text_token=self.args.gpt_start_text_token, + stop_text_token=self.args.gpt_stop_text_token, + heads=self.args.gpt_n_heads, + max_text_tokens=self.args.gpt_max_text_tokens, + max_mel_tokens=self.args.gpt_max_audio_tokens, + 
max_prompt_tokens=self.args.gpt_max_prompt_tokens, + number_text_tokens=self.args.gpt_number_text_tokens, + num_audio_tokens=self.args.gpt_num_audio_tokens, + start_audio_token=self.args.gpt_start_audio_token, + stop_audio_token=self.args.gpt_stop_audio_token, + ).cuda() + + + # load GPT if available + if self.args.gpt_checkpoint: + gpt_checkpoint = torch.load( + self.args.gpt_checkpoint, map_location=torch.device("cpu") + ) + # deal with coqui Trainer exported model + if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): + print("Coqui Trainer checkpoint detected! Converting it!") + gpt_checkpoint = gpt_checkpoint["model"] + states_keys = list(gpt_checkpoint.keys()) + for key in states_keys: + if "gpt." in key: + new_key = key.replace("gpt.", "") + gpt_checkpoint[new_key] = gpt_checkpoint[key] + del gpt_checkpoint[key] + else: + del gpt_checkpoint[key] + + # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible + if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape: + num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] + print(f" > Loading checkpoint with {num_new_tokens} additional tokens.") + + # add new tokens to a linear layer (text_head) + emb_g = gpt_checkpoint["text_embedding.weight"] + new_row = torch.randn(num_new_tokens, emb_g.shape[1]) + start_token_row = emb_g[-1, :] + emb_g = torch.cat([emb_g, new_row], axis=0) + emb_g[-1, :] = start_token_row + gpt_checkpoint["text_embedding.weight"] = emb_g + + # add new weights to the linear layer (text_head) + text_head_weight = gpt_checkpoint["text_head.weight"] + start_token_row = text_head_weight[-1, :] + new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1]) + text_head_weight = torch.cat([text_head_weight, new_entry], axis=0) + text_head_weight[-1, :] = start_token_row + gpt_checkpoint["text_head.weight"] = text_head_weight + + # add new biases to the linear layer (text_head) + text_head_bias = gpt_checkpoint["text_head.bias"] + start_token_row = text_head_bias[-1] + new_bias_entry = torch.zeros(num_new_tokens) + text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0) + text_head_bias[-1] = start_token_row + gpt_checkpoint["text_head.bias"] = text_head_bias + + self.gpt.load_state_dict(gpt_checkpoint, strict=True) + print(">> GPT weights restored from:", self.args.gpt_checkpoint) + else: + print(">> GPT weights randomly initialized! 
If you want you can specify a checkpoint in config.model_args.gpt_checkpoint") + + # Mel spectrogram extractor for conditioning + self.torch_mel_spectrogram = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.sample_rate) + + # Load DVAE + self.dvae = DiscreteVAE( + channels=80, + normalization=None, + positional_dims=1, + num_tokens=self.args.gpt_num_audio_tokens - 2, + codebook_dim=512, + hidden_dim=512, + num_resnet_blocks=3, + kernel_size=3, + num_layers=2, + use_transposed_convs=False, + ) + + self.dvae.eval() + if self.args.dvae_checkpoint: + dvae_checkpoint = torch.load( + self.args.dvae_checkpoint, map_location=torch.device("cpu") + ) + self.dvae.load_state_dict(dvae_checkpoint, strict=False) + print(">> DVAE weights restored from:", self.args.dvae_checkpoint) + else: + raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!") + + # Mel spectrogram extractor for DVAE + self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, num_samples, 80,t_m) + cond_lengths: long tensor, (b,) + """ + losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens) + return losses + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 + return {}, {} + + def format_batch(self, batch: Dict) -> Dict: + return batch + + @torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + batch["text_lengths"] = batch["text_lengths"] + batch["wav_lengths"] = batch["wav_lengths"] + batch["text_inputs"] = batch["padded_text"] + batch["cond_lens"] = batch["cond_lens"] + batch["cond_idxs"] = batch["cond_idxs"] + # compute conditioning mel specs + # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor + B, num_cond_samples, C, T = batch["conditioning"].size() + conditioning_reshaped = batch["conditioning"].view(B*num_cond_samples, C, T) + paired_conditioning_mel = self.torch_mel_spectrogram(conditioning_reshaped) + # transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel]) + n_mel = self.torch_mel_spectrogram.n_mel_channels # paired_conditioning_mel.size(1) + T_mel = paired_conditioning_mel.size(2) + paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel) + # get the conditioning embeddings + batch["cond_mels"] = paired_conditioning_mel + # compute codes using DVAE + if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate: + dvae_wav = torchaudio.functional.resample( + batch["wav"], + orig_freq=self.config.audio.sample_rate, + new_freq=self.config.audio.dvae_sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + 
resampling_method="kaiser_window", + beta=14.769656459379492, + ) + else: + dvae_wav = batch["wav"] + dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav) + codes = self.dvae.get_codebook_indices(dvae_mel_spec) + batch["audio_codes"] = codes + # delete useless batch tensors + del batch["padded_text"] + del batch["wav"] + del batch["conditioning"] + + return batch + + def train_step(self, batch, criterion): + loss_dict = {} + cond_mels = batch["cond_mels"] + text_inputs = batch["text_inputs"] + text_lengths = batch["text_lengths"] + audio_codes = batch["audio_codes"] + wav_lengths = batch["wav_lengths"] + + cond_lens=batch["cond_lens"] + # Todo: implement masking on the cond slice + cond_idxs = batch["cond_idxs"] + + loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens) + loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight + loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight + loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"] + return {"model_outputs": None}, loss_dict + + def eval_step(self, batch, criterion): + return self.train_step(batch, criterion) + + def on_epoch_start(self, trainer): # pylint: disable=W0613 + # guarante that dvae will be in eval mode after .train() on evaluation end + self.dvae = self.dvae.eval() + + def on_init_end(self, trainer): # pylint: disable=W0613 + # ignore similarities.pth on clearml save/upload + if self.config.dashboard_logger.lower() == "clearml": + from clearml.binding.frameworks import WeightsFileHandler + WeightsFileHandler.add_pre_callback(callback_clearml_load_save) + + @torch.no_grad() + def inference( + self, + x, + aux_input=None, + ): # pylint: disable=dangerous-default-value + return None + + @staticmethod + def get_criterion(): + return None + + def get_sampler(self, dataset: TTSDataset, num_gpus=1): + # sampler for DDP + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": # pylint: disable=W0613 + if is_eval and not config.run_eval: + loader = None + else: + # Todo: remove the randomness of dataset when it is eval + # init dataloader + dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate) + + # wait all the DDP process to be ready + if num_gpus > 1: + torch.distributed.barrier() + + # sort input sequences from short to long + # dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(dataset, num_gpus) + + # ignore sampler when is eval because if we changed the sampler parameter we will not be able to compare previous runs + if sampler is None or is_eval: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the optimizer based on the config parameters. 
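+ With optimizer_wd_only_on_weights enabled, plain weight parameters receive weight decay while biases, normalization and embedding parameters are placed in a zero weight-decay group, mirroring the Tortoise optimizer setup.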
+ """ + # ToDo: deal with multi GPU training + if self.config.optimizer_wd_only_on_weights: + # parameters to only GPT model + net = self.gpt + + # normalizations + norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, + nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) + # nn.Embedding + emb_modules = (nn.Embedding, nn.EmbeddingBag) + + param_names_notweights = set() + all_param_names = set() + param_map = {} + for mn, m in net.named_modules(): + for k, v in m.named_parameters(): + v.is_bias = k.endswith(".bias") + v.is_weight = k.endswith(".weight") + v.is_norm = isinstance(m, norm_modules) + v.is_emb = isinstance(m, emb_modules) + + fpn = '%s.%s' % (mn, k) if mn else k # full param name + all_param_names.add(fpn) + param_map[fpn] = v + if v.is_bias or v.is_norm or v.is_emb: + param_names_notweights.add(fpn) + + params_names_notweights = sorted(list(param_names_notweights)) + params_notweights = [param_map[k] for k in params_names_notweights] + params_names_weights = sorted(list(all_param_names ^ param_names_notweights)) + params_weights = [param_map[k] for k in params_names_weights] + + groups = [ + { 'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]}, + { 'params': params_notweights, 'weight_decay': 0} + ] + # torch.optim.AdamW + opt = get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + parameters=groups, + ) + opt._group_names = [params_names_weights, params_names_notweights] + return opt + + return get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + # optimize only for the GPT model + parameters=self.gpt.parameters(), + ) + + def get_scheduler(self, optimizer) -> List: + """Set the scheduler for the optimizer. + + Args: + optimizer: `torch.optim.Optimizer`. + """ + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + def load_checkpoint( + self, + config, + checkpoint_path, + eval=False, + strict=True, + cache_storage="/tmp/tts_cache", + target_protocol="s3", + target_options={"anon": True}, + ): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # load the model weights + self.gpt.load_state_dict(state, strict=strict) + + if eval: + self.eval() + self.set_inference() + assert not self.training + + @staticmethod + def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (GPTConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. 
+ """ + return GPTTrainer(config) + \ No newline at end of file diff --git a/recipes/multilingual/xtts_v1/train_xtts.py b/recipes/multilingual/xtts_v1/train_xtts.py new file mode 100644 index 00000000..fc2b5d8a --- /dev/null +++ b/recipes/multilingual/xtts_v1/train_xtts.py @@ -0,0 +1,364 @@ +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples + +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig + + +config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_german", + meta_file_train="metadata_train_with_previous_audio_key.csv", + language="de", +) + + +config_coqui_MLS_metadata_test_with_previous_audio_key_de = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_german", + meta_file_train="metadata_test_with_previous_audio_key.csv", + language="de", +) + + +config_coqui_MLS_metadata_dev_with_previous_audio_key_de = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_german", + meta_file_train="metadata_dev_with_previous_audio_key.csv", + language="de", +) + + +config_coqui_mls_french_metadata_with_previous_audio_key_fr = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_french/", + meta_file_train="metadata_with_previous_audio_key.csv", + language="fr", +) + + +config_coqui_mls_spanish_metadata_with_previous_audio_key_es = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_spanish/", + meta_file_train="/raid/datasets/MLS/mls_spanish/metadata_with_previous_audio_key.csv", + language="es", +) + + +config_coqui_mls_italian_metadata_with_previous_audio_key_it = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_italian/", + meta_file_train="/raid/datasets/MLS/mls_italian/metadata_with_previous_audio_key.csv", + language="it", +) + + +config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_portuguese/", + meta_file_train="/raid/datasets/MLS/mls_portuguese/metadata_with_previous_audio_key.csv", + language="pt", +) + + +config_coqui_mls_polish_metadata_with_previous_audio_key_pl = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/MLS/mls_polish/", + meta_file_train="/raid/datasets/MLS/mls_polish/metadata_with_previous_audio_key.csv", + language="pl", +) + + +config_coqui_common_voice_metafile_it_train_with_scores_it = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_it_train_with_scores.csv", + language="it", +) + + +config_coqui_common_voice_metafile_it_test_with_scores_it = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_it_test_with_scores.csv", + language="it", +) + + +config_coqui_common_voice_metafile_it_dev_with_scores_it = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_it_dev_with_scores.csv", + language="it", +) + + 
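+# The Common Voice configs below follow the same pattern: one BaseDatasetConfig per metadata CSV (train/test/dev splits for Portuguese and English, a single validated split for the remaining languages).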
+config_coqui_common_voice_metafile_pt_train_with_scores_pt = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_pt_train_with_scores.csv", + language="pt", +) + + +config_coqui_common_voice_metafile_pt_test_with_scores_pt = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_pt_test_with_scores.csv", + language="pt", +) + + +config_coqui_common_voice_metafile_pt_dev_with_scores_pt = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_pt_dev_with_scores.csv", + language="pt", +) + + +config_coqui_common_voice_metafile_en_train_en = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_en_train.csv", + language="en", +) + + +config_coqui_common_voice_metafile_en_test_en = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_en_test.csv", + language="en", +) + + +config_coqui_common_voice_metafile_en_dev_en = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_en_dev.csv", + language="en", +) + + +config_coqui_common_voice_metafile_tr_validated_tr = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_tr_validated.csv", + language="tr", +) + + +config_coqui_common_voice_metafile_ru_validated_ru = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_ru_validated.csv", + language="ru", +) + + +config_coqui_common_voice_metafile_nl_validated_nl = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_nl_validated.csv", + language="nl", +) + + +config_coqui_common_voice_metafile_cs_validated_cs = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_cs_validated.csv", + language="cs", +) + + +config_coqui_common_voice_metafile_fr_validated_fr = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_fr_validated.csv", + language="fr", +) + + +config_coqui_common_voice_metafile_es_validated_es = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_es_validated.csv", + language="es", +) + + +config_coqui_common_voice_metafile_pl_validated_pl = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_pl_validated.csv", + language="pl", +) + + +config_coqui_common_voice_metafile_ar_validated_ar = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + 
meta_file_train="/raid/datasets/common_voice/metafile_ar_validated.csv", + language="ar", +) + + +config_coqui_common_voice_metafile_zh_CN_validated_zh_cn = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_zh-CN_validated.csv", + language="zh-cn", +) + + +config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig( + formatter="coqui", + dataset_name="coqui", + path="/raid/datasets/common_voice/", + meta_file_train="/raid/datasets/common_voice/metafile_ja_validated.csv", + language="ja", +) + +# DATASETS_CONFIG_LIST=[config_coqui_MLS_metadata_train_with_previous_audio_key_de, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_MLS_metadata_dev_with_previous_audio_key_de, config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it, config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt, config_coqui_mls_polish_metadata_with_previous_audio_key_pl, config_coqui_common_voice_metafile_it_train_with_scores_it, config_coqui_common_voice_metafile_it_test_with_scores_it, config_coqui_common_voice_metafile_it_dev_with_scores_it, config_coqui_common_voice_metafile_pt_train_with_scores_pt, config_coqui_common_voice_metafile_pt_test_with_scores_pt, config_coqui_common_voice_metafile_pt_dev_with_scores_pt, config_coqui_common_voice_metafile_en_train_en, config_coqui_common_voice_metafile_en_test_en, config_coqui_common_voice_metafile_en_dev_en, config_coqui_common_voice_metafile_tr_validated_tr, config_coqui_common_voice_metafile_ru_validated_ru, config_coqui_common_voice_metafile_nl_validated_nl, config_coqui_common_voice_metafile_cs_validated_cs, config_coqui_common_voice_metafile_fr_validated_fr, config_coqui_common_voice_metafile_es_validated_es, config_coqui_common_voice_metafile_pl_validated_pl, config_coqui_common_voice_metafile_ar_validated_ar, config_coqui_common_voice_metafile_zh_CN_validated_zh_cn, config_coqui_common_voice_metafile_ja_validated_ja] + +# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it] + +DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it] + +def freeze_layers(trainer): + pass + +def main(): + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=True, + max_wav_length=255995, # ~11.6 seconds + max_text_length=200, + tokenizer_file="/raid/datasets/xtts_models/vocab.json", + mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth", + dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth", + gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth", + gpt_num_audio_tokens=8194, + gpt_start_audio_token=8192, + gpt_stop_audio_token=8193, + ) + audio_config = XttsAudioConfig( + sample_rate=22050, # autoregressive SR + dvae_sample_rate=22050, + diffusion_sample_rate=24000, + output_sample_rate=24000 + ) + config = GPTConfig( + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + 
logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=1000, + save_step=10000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise. However, they used pytorch implementation with modifications to not apply WD to non-weight parameters. We are using default Pytorch + optimizer="AdamW", + optimizer_wd_only_on_weights=True, + optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + # lr=1e-4, # learning rate + # ToDo: implement 500 step warmup like tortoise and EMA weights replaces LR decay with rate: .999 + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + callbacks={"on_epoch_start": freeze_layers} + ) + trainer.fit() + + +if __name__ == "__main__": + RUN_NAME = "GPT_XTTS" + PROJECT_NAME = "XTTS" + OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/" + DASHBOARD_LOGGER = "clearml" + LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/" + RESTORE_PATH = None + SKIP_TRAIN_EPOCH = False + START_WITH_EVAL = True + BATCH_SIZE = 9 + GRAD_ACUMM_STEPS = 28 + + # debug + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + RESTORE_PATH = None + BATCH_SIZE = 2 + GRAD_ACUMM_STEPS = 1 + NUM_LOADERS = 1 + + + + main()
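Note for anyone adapting this recipe: the dataset paths, tokenizer, mel stats, DVAE and GPT checkpoints above are hard-coded to /raid locations. Below is a minimal sketch of pointing the recipe at your own data; the dataset name, path and CSV file are placeholders, assuming a metadata file laid out for the "coqui" formatter the recipe already uses.

from TTS.config.shared_configs import BaseDatasetConfig

# Hypothetical local dataset; path and meta_file_train are placeholders, not part of this patch.
config_my_dataset_de = BaseDatasetConfig(
    formatter="coqui",                # same formatter as every dataset in this recipe
    dataset_name="my_dataset",
    path="/data/my_dataset_de/",
    meta_file_train="metadata.csv",   # relative to `path`, as in the MLS configs above
    language="de",
)

# Swap the hard-coded list in train_xtts.py for your own configs.
DATASETS_CONFIG_LIST = [config_my_dataset_de]

The GPTArgs fields tokenizer_file, mel_norm_file, dvae_checkpoint and gpt_checkpoint likewise need to point at local copies of the XTTS vocabulary, mel statistics, DVAE and GPT weights; a missing dvae_checkpoint raises a RuntimeError in GPTTrainer.__init__.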