Update GPT Trainer for perceiver support

Edresson Casanova 2023-11-01 13:37:06 -03:00 committed by Eren Gölge
parent dff3902ca8
commit 8479a3702c
4 changed files with 17 additions and 12 deletions

TTS/tts/layers/xtts/gpt.py

@@ -494,10 +494,9 @@ class GPT(nn.Module):
         # Compute speech conditioning input
         if cond_latents is None:
-            if cond_lens is not None:
+            if cond_lens is not None and cond_idxs is None:
                 min_cond_len = torch.min(cond_lens)
                 cond_mels = cond_mels[:, :, :, :min_cond_len]
             cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
         # Get logits
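The new `cond_idxs is None` guard means the conditioning mels are truncated to the shortest prompt in the batch only when no masking indices accompany it; with `cond_idxs` the full prompt is kept and masked elsewhere. A minimal sketch of that branch, with `get_style_emb` passed in as a stand-in for the model's conditioning encoder (shapes assumed from the docstring further down):

import torch

def compute_cond_latents(get_style_emb, cond_mels, cond_lens=None, cond_idxs=None):
    # cond_mels: (b, num_samples, 80, t_m); get_style_emb is a hypothetical
    # callable standing in for the model's conditioning encoder.
    # Truncate every prompt to the shortest one in the batch only when
    # lengths are given and no masking indices are used.
    if cond_lens is not None and cond_idxs is None:
        min_cond_len = torch.min(cond_lens)
        cond_mels = cond_mels[:, :, :, :min_cond_len]
    return get_style_emb(cond_mels).transpose(1, 2)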

TTS/tts/layers/xtts/trainer/dataset.py

@@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
         self.sample_rate = sample_rate
         self.max_wav_len = model_args.max_wav_length
         self.max_text_len = model_args.max_text_length
-        self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
+        self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt
         assert self.max_wav_len is not None and self.max_text_len is not None
         self.samples = samples
@@ -143,16 +143,18 @@ class XTTSDataset(torch.utils.data.Dataset):
         if self.use_masking_gt_as_prompt:
             # get a slice from GT to condition the model
-            cond, cond_len, cond_idxs = get_prompt_slice(
+            cond, _, cond_idxs = get_prompt_slice(
                 audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
             )
+            # if use masking do not use cond_len
+            cond_len = torch.nan
         else:
             ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
-            cond, cond_len, cond_idxs = get_prompt_slice(
+            cond, cond_len, _ = get_prompt_slice(
                 ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
             )
+            # if do not use masking use cond_len
             cond_idxs = torch.nan
-            cond_len = torch.nan

         return tseq, audiopath, wav, cond, cond_len, cond_idxs
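The two branches now establish an either/or invariant per sample: masking mode keeps real `cond_idxs`, reference-prompt mode keeps a real `cond_len`, and the unused field is filled with `torch.nan` as a sentinel. A sketch of just that invariant, with a hypothetical helper name:

import torch

def condition_fields(use_masking_gt_as_prompt, cond_len, cond_idxs):
    # Mirrors the branch logic above: exactly one of (cond_len, cond_idxs)
    # stays real; the other becomes the nan sentinel.
    if use_masking_gt_as_prompt:
        return torch.nan, cond_idxs  # keep the indices, drop the length
    return cond_len, torch.nan       # keep the length, drop the indices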
@@ -230,9 +232,12 @@ class XTTSDataset(torch.utils.data.Dataset):
         batch["conditioning"] = torch.stack(batch["conditioning"])
         batch["cond_lens"] = torch.stack(batch["cond_lens"])
         batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
         if torch.any(batch["cond_idxs"].isnan()):
             batch["cond_lens"] = None
+            batch["cond_idxs"] = None
+        if torch.any(batch["cond_lens"].isnan()):
+            batch["cond_lens"] = None
         max_text_len = batch["text_lengths"].max()
         max_wav_len = batch["wav_lengths"].max()
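The collate step exploits that sentinel: after stacking, any nan in a field means the batch came from the branch that does not use it, so the whole field is dropped to None. A small demonstration of the round trip (the tensor values are made up):

import torch

def null_if_nan(stacked):
    # Drop a stacked conditioning field batch-wide if it holds the sentinel.
    return None if torch.any(stacked.isnan()) else stacked

# Masking mode: per-sample cond_idxs are real, cond_lens carry torch.nan.
cond_idxs = null_if_nan(torch.tensor([[3.0, 80.0], [10.0, 95.0]]))  # kept
cond_lens = null_if_nan(torch.tensor([torch.nan, torch.nan]))       # -> None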

TTS/tts/layers/xtts/trainer/gpt_trainer.py

@@ -52,7 +52,6 @@ class GPTArgs(XttsArgs):
     xtts_checkpoint: str = ""
     gpt_checkpoint: str = ""  # if defined it will replace the gpt weights on xtts model
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues
-    use_masking_gt_as_prompt: bool = True


 def callback_clearml_load_save(operation_type, model_info):
@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -211,9 +210,10 @@ class GPTTrainer(BaseTTS):
             wav_lengths: long tensor, (b,)
             cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
+            cond_lens: long tensor, (b,)
         """
         losses = self.xtts.gpt(
-            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
+            text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs, cond_lens=cond_lens,
         )
         return losses
@@ -283,7 +283,6 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        del batch["cond_lens"]
         return batch

     def train_step(self, batch, criterion):
@@ -294,8 +293,9 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
         cond_idxs = batch["cond_idxs"]
+        cond_lens = batch["cond_lens"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]

TTS/tts/models/xtts.py

@@ -241,6 +241,7 @@ class XttsArgs(Coqpit):
     gpt_num_audio_tokens: int = 8194
     gpt_start_audio_token: int = 8192
     gpt_stop_audio_token: int = 8193
+    gpt_use_masking_gt_as_prompt: bool = True
     gpt_use_perceiver_resampler: bool = False

     # Diffusion Decoder params
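Since GPTArgs subclasses XttsArgs, the relocated flag (now carrying the gpt_ prefix) is set from the trainer config like any other gpt_* option. A hedged usage sketch; the import path and the non-flag values are assumptions, not taken from this commit:

from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs  # assumed path

model_args = GPTArgs(
    max_wav_length=255995,               # illustrative value
    max_text_length=200,                 # illustrative value
    gpt_use_masking_gt_as_prompt=False,  # condition on reference_path prompts
    gpt_use_perceiver_resampler=True,    # the new perceiver conditioning path
)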