From 8479a3702c6ba1ab1f36bda5dd996c1b2d3da1c3 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Wed, 1 Nov 2023 13:37:06 -0300
Subject: [PATCH] Update GPT Trainer for perceiver support

---
 TTS/tts/layers/xtts/gpt.py                 |  3 +--
 TTS/tts/layers/xtts/trainer/dataset.py     | 15 ++++++++++-----
 TTS/tts/layers/xtts/trainer/gpt_trainer.py | 10 +++++-----
 TTS/tts/models/xtts.py                     |  1 +
 4 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 6397fa1d..c3477d52 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -494,10 +494,9 @@ class GPT(nn.Module):
 
         # Compute speech conditioning input
         if cond_latents is None:
-            if cond_lens is not None:
+            if cond_lens is not None and cond_idxs is None:
                 min_cond_len = torch.min(cond_lens)
                 cond_mels = cond_mels[:, :, :, :min_cond_len]
-
             cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
 
         # Get logits
diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py
index e37387ce..3abfa43e 100644
--- a/TTS/tts/layers/xtts/trainer/dataset.py
+++ b/TTS/tts/layers/xtts/trainer/dataset.py
@@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
-        self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
+        self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt
         assert self.max_wav_len is not None and self.max_text_len is not None
 
         self.samples = samples
@@ -143,16 +143,18 @@ class XTTSDataset(torch.utils.data.Dataset):
 
         if self.use_masking_gt_as_prompt:
             # get a slice from GT to condition the model
-            cond, cond_len, cond_idxs = get_prompt_slice(
+            cond, _, cond_idxs = get_prompt_slice(
                 audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
             )
+            # if use masking do not use cond_len
+            cond_len = torch.nan
         else:
             ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
-            cond, cond_len, cond_idxs = get_prompt_slice(
+            cond, cond_len, _ = get_prompt_slice(
                 ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
             )
+            # if do not use masking use cond_len
             cond_idxs = torch.nan
-            cond_len = torch.nan
 
         return tseq, audiopath, wav, cond, cond_len, cond_idxs
 
@@ -230,9 +232,12 @@ class XTTSDataset(torch.utils.data.Dataset):
         batch["conditioning"] = torch.stack(batch["conditioning"])
         batch["cond_lens"] = torch.stack(batch["cond_lens"])
         batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
+
         if torch.any(batch["cond_idxs"].isnan()):
-            batch["cond_lens"] = None
             batch["cond_idxs"] = None
+
+        if torch.any(batch["cond_lens"].isnan()):
+            batch["cond_lens"] = None
 
         max_text_len = batch["text_lengths"].max()
         max_wav_len = batch["wav_lengths"].max()
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 7b870548..99e4529b 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -52,7 +52,6 @@ class GPTArgs(XttsArgs):
     xtts_checkpoint: str = ""
     gpt_checkpoint: str = ""  # if defined it will replace the gpt weights on xtts model
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues
-    use_masking_gt_as_prompt: bool = True
 
 
 def callback_clearml_load_save(operation_type, model_info):
@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device
 
-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -211,9 +210,10 @@ class GPTTrainer(BaseTTS):
         wav_lengths: long tensor, (b,)
         cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
         cond_idxs: cond start and end indexs, (b, 2)
+        cond_lens: long tensor, (b,)
         """
         losses = self.xtts.gpt(
-            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs, cond_lens=cond_lens,
         )
         return losses
 
@@ -283,7 +283,6 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        del batch["cond_lens"]
         return batch
 
     def train_step(self, batch, criterion):
@@ -294,8 +293,9 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
         cond_idxs = batch["cond_idxs"]
+        cond_lens = batch["cond_lens"]
 
-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py
index d8458109..dcf8ad3e 100644
--- a/TTS/tts/models/xtts.py
+++ b/TTS/tts/models/xtts.py
@@ -241,6 +241,7 @@ class XttsArgs(Coqpit):
     gpt_num_audio_tokens: int = 8194
     gpt_start_audio_token: int = 8192
     gpt_stop_audio_token: int = 8193
+    gpt_use_masking_gt_as_prompt: bool = True
     gpt_use_perceiver_resampler: bool = False
 
     # Diffusion Decoder params
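
Illustrative note (not part of the patch): the change relies on a NaN-sentinel convention. The conditioning field that is unused for the current mode is filled with torch.nan per sample and then dropped to None for the whole batch in collate, so GPT.forward can branch on `cond_idxs is None` / `cond_lens is not None` instead of inspecting NaN tensors. Below is a minimal, self-contained sketch of that convention; `simple_collate` is a hypothetical stand-in for XTTSDataset.collate_fn, and the field names and values are assumptions for demonstration only.

import torch


def simple_collate(samples):
    # Stack per-sample conditioning metadata into batch tensors.
    batch = {
        "cond_lens": torch.stack([s["cond_len"] for s in samples]),
        "cond_idxs": torch.stack([s["cond_idxs"] for s in samples]),
    }
    # A NaN sentinel means the field is unused for this batch: drop it to None
    # so downstream code can test `is None` rather than checking for NaNs.
    if torch.any(batch["cond_idxs"].isnan()):
        batch["cond_idxs"] = None
    if torch.any(batch["cond_lens"].isnan()):
        batch["cond_lens"] = None
    return batch


# Masking-GT-as-prompt mode: cond_idxs holds (start, end) sample indices, cond_len is NaN.
masked = [{"cond_len": torch.tensor(float("nan")), "cond_idxs": torch.tensor([0.0, 66150.0])}]
# Reference-audio mode: cond_len holds the prompt length, cond_idxs is NaN.
reference = [{"cond_len": torch.tensor(66150.0), "cond_idxs": torch.tensor([float("nan"), float("nan")])}]

print(simple_collate(masked))     # cond_lens -> None, cond_idxs kept as a tensor
print(simple_collate(reference))  # cond_idxs -> None, cond_lens kept as a tensor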