diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index dce8a137..52086f13 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -254,7 +254,6 @@ class GPT(nn.Module): else: attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1) - gpt_out = self.gpt( inputs_embeds=emb, diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py index 3ac22b5d..b122fc8a 100644 --- a/TTS/tts/layers/xtts/trainer/dataset.py +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -23,14 +23,24 @@ def key_samples_by_col(samples, col): return samples_by_col -def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate): +def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False): rel_clip = load_audio(gt_path, sample_rate) - sample_length = random.randint(min_sample_length, max_sample_length) + # if eval uses a middle size sample when it is possible to be more reproducible + if is_eval: + sample_length = int((min_sample_length + max_sample_length)/2) + else: + sample_length = random.randint(min_sample_length, max_sample_length) gap = rel_clip.shape[-1] - sample_length if gap < 0: sample_length = rel_clip.shape[-1] // 2 gap = rel_clip.shape[-1] - sample_length - rand_start = random.randint(0, gap) + + # if eval start always from the position 0 to be more reproducible + if is_eval: + rand_start = 0 + else: + rand_start = random.randint(0, gap) + rand_end = rand_start+sample_length rel_clip = rel_clip[:, rand_start:rand_end] rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1])) @@ -90,7 +100,7 @@ class XTTSDataset(torch.utils.data.Dataset): self.check_eval_samples() def check_eval_samples(self): - print("Filtering invalid eval samples!!") + print(" > Filtering invalid eval samples!!") new_samples = [] for sample in self.samples: try: @@ -104,7 +114,7 @@ class XTTSDataset(torch.utils.data.Dataset): continue new_samples.append(sample) self.samples = new_samples - print("Total eval samples after filtering:", len(self.samples)) + print(" > Total eval samples after filtering:", len(self.samples)) def get_text(self, text, lang): tokens = self.tokenizer.encode(text, lang) @@ -126,7 +136,7 @@ class XTTSDataset(torch.utils.data.Dataset): raise ValueError # get a slice from GT to condition the model - cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate) + cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval) return tseq, audiopath, wav, cond, cond_len, cond_idxs diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index e8e5752f..d884f12a 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -208,8 +208,8 @@ class GPTTrainer(BaseTTS): text_lengths: long tensor, (b,) mel_inputs: long tensor, (b,m) wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, num_samples, 80,t_m) cond_idxs: cond start and end indexs, (b, 2) - cond_lengths: long tensor, (b,) """ losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs) return losses @@ -269,10 +269,8 @@ class GPTTrainer(BaseTTS): text_lengths = batch["text_lengths"] audio_codes = batch["audio_codes"] wav_lengths = batch["wav_lengths"] - - # Todo: implement masking on the cond slice cond_idxs = batch["cond_idxs"] - + loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs) loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight @@ -280,6 +278,8 @@ class GPTTrainer(BaseTTS): return {"model_outputs": None}, loss_dict def eval_step(self, batch, criterion): + # ignore masking for more consistent evaluation + batch["cond_idxs"] = None return self.train_step(batch, criterion) def on_epoch_start(self, trainer): # pylint: disable=W0613