Update GPT Trainer for perceiver support

This commit is contained in:
Edresson Casanova 2023-11-01 13:37:06 -03:00 committed by Eren G??lge
parent dff3902ca8
commit 8479a3702c
4 changed files with 17 additions and 12 deletions

View File

@ -494,10 +494,9 @@ class GPT(nn.Module):
# Compute speech conditioning input
if cond_latents is None:
if cond_lens is not None:
if cond_lens is not None and cond_idxs is None:
min_cond_len = torch.min(cond_lens)
cond_mels = cond_mels[:, :, :, :min_cond_len]
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
# Get logits

View File

@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
self.sample_rate = sample_rate
self.max_wav_len = model_args.max_wav_length
self.max_text_len = model_args.max_text_length
self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt
assert self.max_wav_len is not None and self.max_text_len is not None
self.samples = samples
@ -143,16 +143,18 @@ class XTTSDataset(torch.utils.data.Dataset):
if self.use_masking_gt_as_prompt:
# get a slice from GT to condition the model
cond, cond_len, cond_idxs = get_prompt_slice(
cond, _, cond_idxs = get_prompt_slice(
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
)
# if use masking do not use cond_len
cond_len = torch.nan
else:
ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
cond, cond_len, cond_idxs = get_prompt_slice(
cond, cond_len, _ = get_prompt_slice(
ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
)
# if do not use masking use cond_len
cond_idxs = torch.nan
cond_len = torch.nan
return tseq, audiopath, wav, cond, cond_len, cond_idxs
@ -230,9 +232,12 @@ class XTTSDataset(torch.utils.data.Dataset):
batch["conditioning"] = torch.stack(batch["conditioning"])
batch["cond_lens"] = torch.stack(batch["cond_lens"])
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
if torch.any(batch["cond_idxs"].isnan()):
batch["cond_lens"] = None
batch["cond_idxs"] = None
if torch.any(batch["cond_lens"].isnan()):
batch["cond_lens"] = None
max_text_len = batch["text_lengths"].max()
max_wav_len = batch["wav_lengths"].max()

View File

@ -52,7 +52,6 @@ class GPTArgs(XttsArgs):
xtts_checkpoint: str = ""
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
use_masking_gt_as_prompt: bool = True
def callback_clearml_load_save(operation_type, model_info):
@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
def device(self):
return next(self.parameters()).device
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
"""
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
(actuated by `text_first`).
@ -211,9 +210,10 @@ class GPTTrainer(BaseTTS):
wav_lengths: long tensor, (b,)
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
cond_idxs: cond start and end indexs, (b, 2)
cond_lens: long tensor, (b,)
"""
losses = self.xtts.gpt(
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs, cond_lens=cond_lens,
)
return losses
@ -283,7 +283,6 @@ class GPTTrainer(BaseTTS):
del batch["padded_text"]
del batch["wav"]
del batch["conditioning"]
del batch["cond_lens"]
return batch
def train_step(self, batch, criterion):
@ -294,8 +293,9 @@ class GPTTrainer(BaseTTS):
audio_codes = batch["audio_codes"]
wav_lengths = batch["wav_lengths"]
cond_idxs = batch["cond_idxs"]
cond_lens = batch["cond_lens"]
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens)
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]

View File

@ -241,6 +241,7 @@ class XttsArgs(Coqpit):
gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193
gpt_use_masking_gt_as_prompt: bool = True
gpt_use_perceiver_resampler: bool = False
# Diffusion Decoder params