mirror of https://github.com/coqui-ai/TTS.git

Bug fix on reproducible evaluation

parent bafab049c2
commit 2f868dd5c2
@@ -255,7 +255,6 @@ class GPT(nn.Module):
             attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
             attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
-

         gpt_out = self.gpt(
             inputs_embeds=emb,
             return_dict=True,
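For readers skimming the context lines above, this is what the mask concatenation does: an all-True block covering the conditioning prompt is prepended to the existing attention mask along the sequence dimension. A toy, self-contained illustration (the shapes batch, offset, and seq_len are invented, not taken from the model):

import torch

batch, offset, seq_len = 2, 3, 5  # made-up sizes
attn_mask = torch.zeros(batch, seq_len, dtype=torch.bool)     # mask over the token sequence
attn_mask_cond = torch.ones(batch, offset, dtype=torch.bool)  # prompt positions always visible
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
print(attn_mask.shape)  # torch.Size([2, 8]): prompt columns first, then the sequence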
@@ -23,14 +23,24 @@ def key_samples_by_col(samples, col):
     return samples_by_col


-def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate):
+def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
     rel_clip = load_audio(gt_path, sample_rate)
-    sample_length = random.randint(min_sample_length, max_sample_length)
+    # if eval uses a middle size sample when it is possible to be more reproducible
+    if is_eval:
+        sample_length = int((min_sample_length + max_sample_length)/2)
+    else:
+        sample_length = random.randint(min_sample_length, max_sample_length)
     gap = rel_clip.shape[-1] - sample_length
     if gap < 0:
         sample_length = rel_clip.shape[-1] // 2
         gap = rel_clip.shape[-1] - sample_length
-    rand_start = random.randint(0, gap)
+
+    # if eval start always from the position 0 to be more reproducible
+    if is_eval:
+        rand_start = 0
+    else:
+        rand_start = random.randint(0, gap)
+
     rand_end = rand_start+sample_length
     rel_clip = rel_clip[:, rand_start:rand_end]
     rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
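A minimal runnable sketch of the eval-time behaviour this hunk introduces: with is_eval=True the slice length is fixed to the middle of the allowed range and the start is pinned to 0, so repeated evaluation runs see the same conditioning prompt. Assumptions: load_audio is replaced by a dummy tensor, the name get_prompt_slice_sketch is hypothetical, and the sample lengths are invented.

import random

import torch
import torch.nn.functional as F


def get_prompt_slice_sketch(rel_clip, max_sample_length, min_sample_length, is_eval=False):
    if is_eval:
        sample_length = int((min_sample_length + max_sample_length) / 2)  # deterministic length
    else:
        sample_length = random.randint(min_sample_length, max_sample_length)
    gap = rel_clip.shape[-1] - sample_length
    if gap < 0:  # clip shorter than the requested slice: fall back to half the clip
        sample_length = rel_clip.shape[-1] // 2
        gap = rel_clip.shape[-1] - sample_length
    rand_start = 0 if is_eval else random.randint(0, gap)  # deterministic start at eval
    rel_clip = rel_clip[:, rand_start : rand_start + sample_length]
    return F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))


clip = torch.randn(1, 22050)  # one second of fake audio at 22.05 kHz
a = get_prompt_slice_sketch(clip, 11025, 5512, is_eval=True)
b = get_prompt_slice_sketch(clip, 11025, 5512, is_eval=True)
assert torch.equal(a, b)  # identical across calls, unlike the random training path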
@@ -90,7 +100,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.check_eval_samples()

     def check_eval_samples(self):
-        print("Filtering invalid eval samples!!")
+        print(" > Filtering invalid eval samples!!")
         new_samples = []
         for sample in self.samples:
             try:
@@ -104,7 +114,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 continue
             new_samples.append(sample)
         self.samples = new_samples
-        print("Total eval samples after filtering:", len(self.samples))
+        print(" > Total eval samples after filtering:", len(self.samples))

     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
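The two hunks above touch the head and tail of check_eval_samples; the elided middle tries to load each sample and skips the ones that fail. A hypothetical condensed sketch of the whole filtering pattern (DatasetSketch and the stand-in load_item are invented for illustration):

class DatasetSketch:
    def __init__(self, samples, load_item):
        self.samples = samples
        self.load_item = load_item

    def check_eval_samples(self):
        print(" > Filtering invalid eval samples!!")
        new_samples = []
        for sample in self.samples:
            try:
                self.load_item(sample)  # stand-in; the real check also tests lengths
            except Exception:
                continue
            new_samples.append(sample)
        self.samples = new_samples
        print(" > Total eval samples after filtering:", len(self.samples))


ds = DatasetSketch([1, 2, 3], lambda s: 1 / (s % 2))  # even samples raise ZeroDivisionError
ds.check_eval_samples()
assert ds.samples == [1, 3]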
@@ -126,7 +136,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             raise ValueError

         # get a slice from GT to condition the model
-        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate)
+        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)

         return tseq, audiopath, wav, cond, cond_len, cond_idxs
@@ -208,8 +208,8 @@ class GPTTrainer(BaseTTS):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
-        cond_lengths: long tensor, (b,)
+        cond_idxs: cond start and end indexs, (b, 2)
         """
         losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses
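To make the docstring shapes concrete, here are hypothetical dummy tensors with those layouts. All sizes (b=2, num_samples=1, t_m=32, m=100, the index values) are invented for illustration; only the names and shapes come from the docstring above.

import torch

b, num_samples, t_m, m = 2, 1, 32, 100               # invented sizes
text_lengths = torch.randint(1, 20, (b,))            # (b,)   long
mel_inputs = torch.randint(0, 1024, (b, m))          # (b, m) long token codes
wav_lengths = torch.randint(1, 22050, (b,))          # (b,)   long
cond_mels = torch.randn(b, num_samples, 80, t_m)     # (b, num_samples, 80, t_m) float
cond_idxs = torch.tensor([[0, 8000], [0, 8000]])     # (b, 2) cond slice start/end in samples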
@@ -269,8 +269,6 @@ class GPTTrainer(BaseTTS):
         text_lengths = batch["text_lengths"]
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
-
-        # Todo: implement masking on the cond slice
         cond_idxs = batch["cond_idxs"]

         loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
@@ -280,6 +278,8 @@ class GPTTrainer(BaseTTS):
         return {"model_outputs": None}, loss_dict

     def eval_step(self, batch, criterion):
+        # ignore masking for more consistent evaluation
+        batch["cond_idxs"] = None
         return self.train_step(batch, criterion)

     def on_epoch_start(self, trainer):  # pylint: disable=W0613
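The hunk above is the heart of the fix: eval_step clears the conditioning indices so the random cond-slice masking path is skipped, then reuses the exact train_step code so train and eval losses stay comparable. A toy sketch of the delegation pattern (TrainerSketch and its bodies are illustrative stand-ins, not the real GPTTrainer):

class TrainerSketch:
    def train_step(self, batch, criterion):
        cond_idxs = batch["cond_idxs"]  # None during eval -> no masking applied
        loss_dict = {"masking_used": cond_idxs is not None}
        return {"model_outputs": None}, loss_dict

    def eval_step(self, batch, criterion):
        batch["cond_idxs"] = None  # ignore masking for more consistent evaluation
        return self.train_step(batch, criterion)


_, losses = TrainerSketch().eval_step({"cond_idxs": [[0, 8000]]}, criterion=None)
assert losses == {"masking_used": False}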