From bafab049c210263c26e13659c6445ccb213f2d67 Mon Sep 17 00:00:00 2001
From: Edresson Casanova
Date: Mon, 16 Oct 2023 09:06:59 -0300
Subject: [PATCH] Add prompt masking

---
 TTS/tts/layers/xtts/gpt.py                 | 67 +++++++++++++---------
 TTS/tts/layers/xtts/trainer/gpt_trainer.py | 14 ++---
 recipes/multilingual/xtts_v1/train_xtts.py |  4 +-
 3 files changed, 46 insertions(+), 39 deletions(-)

diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py
index 88ce100c..dce8a137 100644
--- a/TTS/tts/layers/xtts/gpt.py
+++ b/TTS/tts/layers/xtts/gpt.py
@@ -233,6 +233,7 @@ class GPT(nn.Module):
         prompt=None,
         get_attns=False,
         return_latent=False,
+        attn_mask_cond=None,
         attn_mask_text=None,
         attn_mask_mel=None,
     ):
@@ -248,8 +249,12 @@ class GPT(nn.Module):
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
+                if attn_mask_cond is not None:
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+                else:
+                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+

         gpt_out = self.gpt(
             inputs_embeds=emb,
@@ -326,7 +331,7 @@ class GPT(nn.Module):
         prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
         return prompt

-    def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
+    def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
@@ -335,26 +340,7 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            if sample:
-                _len_secs = random.randint(2, 6)  # in secs
-                cond_seg_len = int((22050 / 1024) * _len_secs)  # in frames
-                if cond_input.shape[-1] >= cond_seg_len:
-                    new_conds = []
-                    for i in range(cond_input.shape[0]):
-                        cond_len = int(cond_lens[i] / 1024)
-                        if cond_len < cond_seg_len:
-                            start = 0
-                        else:
-                            start = random.randint(0, cond_len - cond_seg_len)
-                        cond_vec = cond_input[i, :, start : start + cond_seg_len]
-                        new_conds.append(cond_vec)
-                    conds = torch.stack(new_conds, dim=0)
-            else:
-                cond_seg_len = 5 if cond_seg_len is None else cond_seg_len  # secs
-                cond_frame_len = int((22050 / 1024) * cond_seg_len)
-                conds = cond_input[:, :, -cond_frame_len:]
-
-            conds = self.conditioning_encoder(conds)
+            conds = self.conditioning_encoder(cond_input)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
@@ -366,10 +352,9 @@ class GPT(nn.Module):
         text_lengths,
         audio_codes,
         wav_lengths,
-        cond_lens=None,
         cond_mels=None,
+        cond_idxs=None,
         cond_latents=None,
-        loss_weights=None,
         return_attentions=False,
         return_latent=False,
     ):
@@ -377,11 +362,12 @@
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning
         mode (actuated by `text_first`).

-        cond_mels: MEL float tensor, (b, 1, 80,s)
         text_inputs: long tensor, (b,t)
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
+        cond_mels: MEL float tensor, (b, 1, 80,s)
+        cond_idxs: cond start and end indices, (b, 2)

         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ class GPT(nn.Module):
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

+        if cond_idxs is not None:
+            # recompute cond idxs for mel lengths
+            for idx, l in enumerate(code_lengths):
+                cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
+
         # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()

@@ -435,9 +426,16 @@
         )

         # Set attn_mask
+        attn_mask_cond = None
         attn_mask_text = None
         attn_mask_mel = None
         if not return_latent:
+            attn_mask_cond = torch.ones(
+                cond_mels.shape[0],
+                cond_mels.shape[-1],
+                dtype=torch.bool,
+                device=text_inputs.device,
+            )
             attn_mask_text = torch.ones(
                 text_inputs.shape[0],
                 text_inputs.shape[1],
@@ -451,6 +449,11 @@ class GPT(nn.Module):
                 device=audio_codes.device,
             )

+            if cond_idxs is not None:
+                for idx, r in enumerate(cond_idxs.squeeze()):
+                    l = r[1] - r[0]
+                    attn_mask_cond[idx, l : ] = 0.0
+
             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0

@@ -465,7 +468,7 @@ class GPT(nn.Module):

         # Compute speech conditioning input
         if cond_latents is None:
-            cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
+            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

         # Get logits
         sub = -5  # don't ask me why 😄
@@ -480,6 +483,7 @@ class GPT(nn.Module):
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
+            attn_mask_cond=attn_mask_cond,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
@@ -495,12 +499,19 @@ class GPT(nn.Module):

         for idx, l in enumerate(code_lengths):
             mel_targets[idx, l + 1 :] = -1
-        
+
         # check if stoptoken is in every row of mel_targets
         assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
             0
         ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

+        # ignore the loss for the segment used for conditioning
+        # coin flip for the segment to be ignored
+        if cond_idxs is not None:
+            cond_start = cond_idxs[idx, 0]
+            cond_end = cond_idxs[idx, 1]
+            mel_targets[idx, cond_start:cond_end] = -1
+
         # Compute losses
         loss_text = F.cross_entropy(
             text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
index 71cfd6e4..e8e5752f 100644
--- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py
+++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py
@@ -36,7 +36,6 @@ class GPTConfig(TortoiseConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
-    use_weighted_loss: bool = False  # TODO: move it to the base config
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})

@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning
         mode (actuated by `text_first`).
@@ -209,10 +208,10 @@ class GPTTrainer(BaseTTS):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
-        cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
+        cond_idxs: cond start and end indices, (b, 2)
         cond_lengths: long tensor, (b,)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
+        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
@@ -228,7 +227,6 @@ class GPTTrainer(BaseTTS):
         batch["text_lengths"] = batch["text_lengths"]
         batch["wav_lengths"] = batch["wav_lengths"]
         batch["text_inputs"] = batch["padded_text"]
-        batch["cond_lens"] = batch["cond_lens"]
         batch["cond_idxs"] = batch["cond_idxs"]
         # compute conditioning mel specs
         # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor
@@ -261,7 +259,7 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        
+        del batch["cond_lens"]
         return batch

     def train_step(self, batch, criterion):
@@ -272,11 +270,10 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]

-        cond_lens=batch["cond_lens"]
         # Todo: implement masking on the cond slice
         cond_idxs = batch["cond_idxs"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
@@ -325,7 +322,6 @@ class GPTTrainer(BaseTTS):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            # Todo: remove the randomness of dataset when it is eval
            # init dataloader
            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)

diff --git a/recipes/multilingual/xtts_v1/train_xtts.py b/recipes/multilingual/xtts_v1/train_xtts.py
index 4e987a4f..429e4e3a 100644
--- a/recipes/multilingual/xtts_v1/train_xtts.py
+++ b/recipes/multilingual/xtts_v1/train_xtts.py
@@ -253,7 +253,7 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(

 # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
 DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
-
+

 def freeze_layers(trainer):
     pass
@@ -262,7 +262,7 @@ def main():
     model_args = GPTArgs(
         max_conditioning_length=132300,  # 6 secs
         min_conditioning_length=66150,  # 3 secs
-        debug_loading_failures=True,
+        debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
         tokenizer_file="/raid/datasets/xtts_models/vocab.json",
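
Illustrative sketch (not part of the patch above): the snippet below shows, in isolation, the two masking ideas this patch introduces — a boolean attention mask that hides the unused tail of the conditioning prompt, and an ignore_index mask that excludes the conditioning segment from the mel cross-entropy loss. All sizes, tensors, and variable names are invented for the example and are not taken from the XTTS code.

import torch
import torch.nn.functional as F

# Assumed toy sizes; real values come from the dataset/config.
batch, cond_len, code_len, n_codes = 2, 10, 12, 32

# Per-sample (start, end) of the conditioning segment, already in frame units.
cond_idxs = torch.tensor([[0, 6], [0, 4]])

# 1) Attention mask over the conditioning prompt: keep only the frames that
#    were actually used and mask the padded tail (the role of attn_mask_cond).
attn_mask_cond = torch.ones(batch, cond_len, dtype=torch.bool)
for i, (start, end) in enumerate(cond_idxs.tolist()):
    attn_mask_cond[i, end - start :] = False

# 2) Loss mask over the mel-code targets: positions overlapping the
#    conditioning segment are set to ignore_index so they add no gradient
#    (the role of the mel_targets update).
ignore_index = -1
mel_targets = torch.randint(0, n_codes, (batch, code_len))
for i, (start, end) in enumerate(cond_idxs.tolist()):
    mel_targets[i, start:end] = ignore_index

# Dummy logits in the (batch, classes, time) layout expected by F.cross_entropy.
mel_logits = torch.randn(batch, n_codes, code_len)
loss_mel = F.cross_entropy(mel_logits, mel_targets.long(), ignore_index=ignore_index)

print(attn_mask_cond)
print(loss_mel)

The indexing mirrors the patch: the prompt tensor is treated as already cropped to the segment, so its mask is built from the segment length (end - start), while the targets still contain the segment at its original position, so the loss mask uses the [start:end] slice directly.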