Bug fix on reproducible evaluation

Edresson Casanova 2023-10-16 09:28:32 -03:00
parent bafab049c2
commit 2f868dd5c2
3 changed files with 20 additions and 11 deletions

TTS/tts/layers/xtts/gpt.py

@@ -255,7 +255,6 @@ class GPT(nn.Module):
             attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
             attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
 
-
         gpt_out = self.gpt(
             inputs_embeds=emb,
             return_dict=True,
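Note: the context lines above build the attention mask by prepending an all-True block that covers the conditioning prompt. A minimal standalone sketch of that pattern follows (the shapes and the base padding mask are illustrative, not taken from the repo):

    import torch

    batch_size, offset, seq_len = 2, 3, 5
    # hypothetical padding mask over the token sequence (True = position is attended)
    attn_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    attn_mask[1, -2:] = False  # pretend sample 1 ends in two padded positions
    # prepend an all-True block covering the `offset` conditioning positions, as in the diff
    attn_mask_cond = torch.ones(batch_size, offset, dtype=torch.bool)
    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
    print(attn_mask.shape)  # torch.Size([2, 8])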

TTS/tts/layers/xtts/trainer/dataset.py

@@ -23,14 +23,24 @@ def key_samples_by_col(samples, col):
     return samples_by_col
 
 
-def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate):
+def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
     rel_clip = load_audio(gt_path, sample_rate)
-    sample_length = random.randint(min_sample_length, max_sample_length)
+    # for eval, use a fixed middle-length sample when possible, to be more reproducible
+    if is_eval:
+        sample_length = int((min_sample_length + max_sample_length) / 2)
+    else:
+        sample_length = random.randint(min_sample_length, max_sample_length)
     gap = rel_clip.shape[-1] - sample_length
     if gap < 0:
         sample_length = rel_clip.shape[-1] // 2
         gap = rel_clip.shape[-1] - sample_length
-    rand_start = random.randint(0, gap)
+    # for eval, always start from position 0, to be more reproducible
+    if is_eval:
+        rand_start = 0
+    else:
+        rand_start = random.randint(0, gap)
     rand_end = rand_start + sample_length
     rel_clip = rel_clip[:, rand_start:rand_end]
     rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
@@ -90,7 +100,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             self.check_eval_samples()
 
     def check_eval_samples(self):
-        print("Filtering invalid eval samples!!")
+        print(" > Filtering invalid eval samples!!")
         new_samples = []
         for sample in self.samples:
             try:
@@ -104,7 +114,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 continue
             new_samples.append(sample)
         self.samples = new_samples
-        print("Total eval samples after filtering:", len(self.samples))
+        print(" > Total eval samples after filtering:", len(self.samples))
 
     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
@ -126,7 +136,7 @@ class XTTSDataset(torch.utils.data.Dataset):
raise ValueError raise ValueError
# get a slice from GT to condition the model # get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate) cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
return tseq, audiopath, wav, cond, cond_len, cond_idxs return tseq, audiopath, wav, cond, cond_len, cond_idxs
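Note: the behavioral change is easiest to see with the slice-selection logic pulled out on its own. The sketch below is a hypothetical, self-contained mirror of get_prompt_slice's selection code (audio loading replaced by a plain clip length; function name and sample values are illustrative): with is_eval=True the window is identical on every call, while training still samples a random length and start.

    import random

    def pick_slice(clip_len, min_len, max_len, is_eval=False):
        # mirrors the selection logic in get_prompt_slice, without audio I/O
        if is_eval:
            sample_length = int((min_len + max_len) / 2)  # fixed mid-length for eval
        else:
            sample_length = random.randint(min_len, max_len)
        gap = clip_len - sample_length
        if gap < 0:  # clip shorter than requested: fall back to half the clip
            sample_length = clip_len // 2
            gap = clip_len - sample_length
        start = 0 if is_eval else random.randint(0, gap)  # eval always starts at 0
        return start, start + sample_length

    # eval windows are deterministic across calls; training windows are not
    assert pick_slice(22050, 6615, 13230, is_eval=True) == pick_slice(22050, 6615, 13230, is_eval=True)
    print(pick_slice(22050, 6615, 13230, is_eval=True))   # same window every run
    print(pick_slice(22050, 6615, 13230, is_eval=False))  # varies run to run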

TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -208,8 +208,8 @@ class GPTTrainer(BaseTTS):
             text_lengths: long tensor, (b,)
             mel_inputs: long tensor, (b,m)
             wav_lengths: long tensor, (b,)
-            cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
             cond_idxs: cond start and end indices, (b, 2)
+            cond_lengths: long tensor, (b,)
         """
         losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses
@@ -269,8 +269,6 @@ class GPTTrainer(BaseTTS):
         text_lengths = batch["text_lengths"]
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
-
-        # Todo: implement masking on the cond slice
         cond_idxs = batch["cond_idxs"]
 
         loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
@@ -280,6 +278,8 @@ class GPTTrainer(BaseTTS):
         return {"model_outputs": None}, loss_dict
 
     def eval_step(self, batch, criterion):
+        # ignore masking for more consistent evaluation
+        batch["cond_idxs"] = None
         return self.train_step(batch, criterion)
 
     def on_epoch_start(self, trainer):  # pylint: disable=W0613
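Note: the eval_step change reuses train_step but clears cond_idxs first, so the conditioning-slice masking (which depends on the sampled slice positions) never affects evaluation losses. A minimal sketch of that override pattern, using a hypothetical trainer whose train_step masks a slice only when cond_idxs is present:

    import torch
    import torch.nn.functional as F

    class TinyTrainer:
        # hypothetical stand-in for GPTTrainer; only the cond_idxs pattern is modeled
        def train_step(self, batch, criterion):
            x = batch["x"]
            if batch["cond_idxs"] is not None:
                start, end = batch["cond_idxs"]
                x = x.clone()
                x[..., start:end] = 0.0  # mask the conditioning slice during training
            return criterion(x, torch.zeros_like(x))

        def eval_step(self, batch, criterion):
            batch["cond_idxs"] = None  # skip masking so eval losses are comparable
            return self.train_step(batch, criterion)

    trainer = TinyTrainer()
    batch = {"x": torch.ones(1, 8), "cond_idxs": (2, 5)}
    print(trainer.train_step(dict(batch), F.mse_loss))  # slice masked -> loss 0.625
    print(trainer.eval_step(dict(batch), F.mse_loss))   # full input   -> loss 1.0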