mirror of https://github.com/coqui-ai/TTS.git
Bug fix on reproducible evaluation
This commit is contained in:
parent bafab049c2
commit 2f868dd5c2
@@ -255,7 +255,6 @@ class GPT(nn.Module):
                 attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
                 attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
 
-
         gpt_out = self.gpt(
             inputs_embeds=emb,
             return_dict=True,
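
For context on the unchanged lines in this hunk: the conditioning prompt is prepended to the input embeddings, so the attention mask is extended on the left with ones marking the prompt positions as attendable. A minimal standalone sketch of that masking pattern (the shapes here are illustrative assumptions, not values taken from the model):

import torch

# Illustrative shapes only (assumed): 2 sequences, a 4-frame conditioning
# prompt prepended to a 7-token sequence.
batch, offset, seq_len = 2, 4, 7
attn_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
attn_mask[:, :5] = True  # pretend the first 5 positions are real tokens

# Mirror the two context lines above: mark every prompt position as
# attendable and prepend that block to the existing mask.
attn_mask_cond = torch.ones(batch, offset, dtype=torch.bool)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
print(attn_mask.shape)  # torch.Size([2, 11])
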
@@ -23,14 +23,24 @@ def key_samples_by_col(samples, col):
     return samples_by_col
 
 
-def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate):
+def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
     rel_clip = load_audio(gt_path, sample_rate)
-    sample_length = random.randint(min_sample_length, max_sample_length)
+    # if eval uses a middle size sample when it is possible to be more reproducible
+    if is_eval:
+        sample_length = int((min_sample_length + max_sample_length)/2)
+    else:
+        sample_length = random.randint(min_sample_length, max_sample_length)
     gap = rel_clip.shape[-1] - sample_length
     if gap < 0:
         sample_length = rel_clip.shape[-1] // 2
         gap = rel_clip.shape[-1] - sample_length
-    rand_start = random.randint(0, gap)
+
+    # if eval start always from the position 0 to be more reproducible
+    if is_eval:
+        rand_start = 0
+    else:
+        rand_start = random.randint(0, gap)
+
     rand_end = rand_start+sample_length
     rel_clip = rel_clip[:, rand_start:rand_end]
     rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
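
The hunk above is the core of the fix: during evaluation the prompt slice gets a fixed middle-size length and always starts at position 0, so the same sample yields the same conditioning slice on every run. A minimal standalone sketch of the patched logic, with a random tensor standing in for load_audio (the stand-in clip and the length values are assumptions for illustration):

import random

import torch
import torch.nn.functional as F

def get_prompt_slice_sketch(rel_clip, max_sample_length, min_sample_length, is_eval=False):
    # eval: fixed middle-size sample; train: random length in [min, max]
    if is_eval:
        sample_length = int((min_sample_length + max_sample_length) / 2)
    else:
        sample_length = random.randint(min_sample_length, max_sample_length)
    gap = rel_clip.shape[-1] - sample_length
    if gap < 0:  # clip shorter than the requested slice: fall back to half the clip
        sample_length = rel_clip.shape[-1] // 2
        gap = rel_clip.shape[-1] - sample_length

    # eval: always start at position 0; train: random offset within the clip
    rand_start = 0 if is_eval else random.randint(0, gap)
    sliced = rel_clip[:, rand_start : rand_start + sample_length]
    # right-pad to the fixed maximum length, as the original does with F.pad
    return F.pad(sliced, pad=(0, max_sample_length - sliced.shape[-1]))

# A random tensor stands in for load_audio(gt_path, sample_rate).
clip = torch.randn(1, 48000)
a = get_prompt_slice_sketch(clip, max_sample_length=132300, min_sample_length=66150, is_eval=True)
b = get_prompt_slice_sketch(clip, max_sample_length=132300, min_sample_length=66150, is_eval=True)
assert torch.equal(a, b)  # eval slices are now identical across calls
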
@@ -90,7 +100,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.check_eval_samples()
 
     def check_eval_samples(self):
-        print("Filtering invalid eval samples!!")
+        print(" > Filtering invalid eval samples!!")
         new_samples = []
         for sample in self.samples:
             try:
@@ -104,7 +114,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 continue
             new_samples.append(sample)
         self.samples = new_samples
-        print("Total eval samples after filtering:", len(self.samples))
+        print(" > Total eval samples after filtering:", len(self.samples))
 
     def get_text(self, text, lang):
         tokens = self.tokenizer.encode(text, lang)
@@ -126,7 +136,7 @@ class XTTSDataset(torch.utils.data.Dataset):
             raise ValueError
 
         # get a slice from GT to condition the model
-        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate)
+        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval)
 
         return tseq, audiopath, wav, cond, cond_len, cond_idxs
 
@@ -208,8 +208,8 @@ class GPTTrainer(BaseTTS):
             text_lengths: long tensor, (b,)
             mel_inputs: long tensor, (b,m)
             wav_lengths: long tensor, (b,)
+            cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
-            cond_lengths: long tensor, (b,)
         """
         losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses
@@ -269,8 +269,6 @@ class GPTTrainer(BaseTTS):
         text_lengths = batch["text_lengths"]
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
-
-        # Todo: implement masking on the cond slice
         cond_idxs = batch["cond_idxs"]
 
         loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
@@ -280,6 +278,8 @@ class GPTTrainer(BaseTTS):
         return {"model_outputs": None}, loss_dict
 
     def eval_step(self, batch, criterion):
+        # ignore masking for more consistent evaluation
+        batch["cond_idxs"] = None
        return self.train_step(batch, criterion)
 
     def on_epoch_start(self, trainer):  # pylint: disable=W0613
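
This eval_step change complements the dataset fix: evaluation reuses train_step but clears cond_idxs, so the masking of the conditioning slice, which varies with the randomly chosen slice during training, never affects evaluation. A toy sketch of the delegation pattern (the class name and batch keys are hypothetical stand-ins; only the eval_step body mirrors the hunk):

class TrainerSketch:
    def train_step(self, batch, criterion):
        cond_idxs = batch["cond_idxs"]  # None disables masking of the cond slice
        # ... losses would be computed here, with masking only if cond_idxs is set ...
        return {"cond_masked": cond_idxs is not None}

    def eval_step(self, batch, criterion):
        # ignore masking for more consistent evaluation
        batch["cond_idxs"] = None
        return self.train_step(batch, criterion)

out = TrainerSketch().eval_step({"cond_idxs": [[0, 100]]}, criterion=None)
assert out == {"cond_masked": False}
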