Add prompt masking

Edresson Casanova 2023-10-16 09:06:59 -03:00
parent 47d613df3a
commit bafab049c2
3 changed files with 46 additions and 39 deletions

View File

@@ -233,6 +233,7 @@ class GPT(nn.Module):
         prompt=None,
         get_attns=False,
         return_latent=False,
+        attn_mask_cond=None,
         attn_mask_text=None,
         attn_mask_mel=None,
     ):
@@ -248,8 +249,12 @@ class GPT(nn.Module):
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
+                if attn_mask_cond is not None:
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+                else:
+                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

         gpt_out = self.gpt(
             inputs_embeds=emb,
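
The new branch lets the caller pass a precomputed attn_mask_cond for the prompt positions; when none is given, the prompt is assumed fully valid, which reproduces the old behavior. A minimal standalone sketch of the concatenation order (build_attn_mask and its arguments are illustrative names, not part of the codebase):

import torch

def build_attn_mask(attn_mask_text, attn_mask_mel, prompt_len, attn_mask_cond=None):
    # text and mel masks are (b, t_text) and (b, t_mel) bool tensors
    attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
    if attn_mask_cond is None:
        # no explicit prompt mask: treat every prompt frame as valid
        attn_mask_cond = torch.ones(
            attn_mask.shape[0], prompt_len, dtype=torch.bool, device=attn_mask.device
        )
    # prompt embeddings are prepended to the input sequence,
    # so their mask slice has to come first as well
    return torch.cat([attn_mask_cond, attn_mask], dim=1)

b, t_text, t_mel, t_prompt = 2, 5, 7, 3
mask = build_attn_mask(
    torch.ones(b, t_text, dtype=torch.bool),
    torch.ones(b, t_mel, dtype=torch.bool),
    prompt_len=t_prompt,
)
print(mask.shape)  # torch.Size([2, 15])
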
@@ -326,7 +331,7 @@ class GPT(nn.Module):
             prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
         return prompt

-    def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
+    def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
@@ -335,26 +340,7 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            if sample:
-                _len_secs = random.randint(2, 6)  # in secs
-                cond_seg_len = int((22050 / 1024) * _len_secs)  # in frames
-                if cond_input.shape[-1] >= cond_seg_len:
-                    new_conds = []
-                    for i in range(cond_input.shape[0]):
-                        cond_len = int(cond_lens[i] / 1024)
-                        if cond_len < cond_seg_len:
-                            start = 0
-                        else:
-                            start = random.randint(0, cond_len - cond_seg_len)
-                        cond_vec = cond_input[i, :, start : start + cond_seg_len]
-                        new_conds.append(cond_vec)
-                    conds = torch.stack(new_conds, dim=0)
-            else:
-                cond_seg_len = 5 if cond_seg_len is None else cond_seg_len  # secs
-                cond_frame_len = int((22050 / 1024) * cond_seg_len)
-                conds = cond_input[:, :, -cond_frame_len:]
-            conds = self.conditioning_encoder(conds)
+            conds = self.conditioning_encoder(cond_input)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
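
get_style_emb no longer samples a random 2-6 s window internally; the full conditioning mel goes through the encoder, and prompt selection and masking are handled upstream via cond_idxs. A rough sketch of the simplified flow, with a 1x1 Conv1d standing in for the real conditioning encoder (illustrative only; the actual conditioning_encoder is a different module):

import torch
import torch.nn as nn

class StyleEmbSketch(nn.Module):
    def __init__(self, n_mels=80, dim=1024):
        super().__init__()
        # stand-in for the repo's conditioning encoder
        self.conditioning_encoder = nn.Conv1d(n_mels, dim, kernel_size=1)

    def get_style_emb(self, cond_input, return_latent=False):
        # cond_input: (b, 80, s) or (b, 1, 80, s)
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)  # (b, 1024, s)
        else:
            # latents were already computed
            conds = cond_input.unsqueeze(1)
        return conds

m = StyleEmbSketch()
print(m.get_style_emb(torch.randn(2, 1, 80, 130)).shape)  # torch.Size([2, 1024, 130])
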
@@ -366,10 +352,9 @@ class GPT(nn.Module):
         text_lengths,
         audio_codes,
         wav_lengths,
-        cond_lens=None,
         cond_mels=None,
+        cond_idxs=None,
         cond_latents=None,
-        loss_weights=None,
         return_attentions=False,
         return_latent=False,
     ):
@@ -377,11 +362,12 @@ class GPT(nn.Module):
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).

-        cond_mels: MEL float tensor, (b, 1, 80, s)
         text_inputs: long tensor, (b, t)
         text_lengths: long tensor, (b,)
         audio_codes: long tensor, (b, m)
         wav_lengths: long tensor, (b,)
+        cond_mels: MEL float tensor, (b, 1, 80, s)
+        cond_idxs: cond start and end indices, (b, 2)

         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ class GPT(nn.Module):
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

+        if cond_idxs is not None:
+            # rescale cond_idxs from wav samples to code-frame units to match the code lengths
+            for idx in range(cond_idxs.shape[0]):
+                cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len
+
         # If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()
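
cond_idxs arrive measured in waveform samples, so they are rescaled by code_stride_len before they can index the code sequence. With code_stride_len = 1024, a 6 s prompt at 22050 Hz (132300 samples) spans 129 code frames; a quick check under those assumptions:

import torch

code_stride_len = 1024  # samples per audio code frame (repo constant)
cond_idxs = torch.tensor([[0, 132300], [22050, 88200]])  # (b, 2) in samples

cond_idxs_frames = cond_idxs // code_stride_len
print(cond_idxs_frames)  # tensor([[  0, 129], [ 21,  86]])
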
@@ -435,9 +426,16 @@ class GPT(nn.Module):
         )

         # Set attn_mask
+        attn_mask_cond = None
         attn_mask_text = None
         attn_mask_mel = None
         if not return_latent:
+            attn_mask_cond = torch.ones(
+                cond_mels.shape[0],
+                cond_mels.shape[-1],
+                dtype=torch.bool,
+                device=text_inputs.device,
+            )
             attn_mask_text = torch.ones(
                 text_inputs.shape[0],
                 text_inputs.shape[1],
@@ -451,6 +449,11 @@ class GPT(nn.Module):
                 device=audio_codes.device,
             )

+            if cond_idxs is not None:
+                for idx, r in enumerate(cond_idxs):
+                    l = r[1] - r[0]
+                    attn_mask_cond[idx, l:] = 0.0
+
             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0
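
attn_mask_cond starts all-ones, sized to the padded prompt (cond_mels.shape[-1]); the loop then zeroes every frame past each sample's true segment length r[1] - r[0], so attention never lands on prompt padding. A toy reproduction with invented shapes:

import torch

cond_idxs = torch.tensor([[0, 129], [0, 60]])  # (b, 2) frame indices
max_prompt_len = 129

attn_mask_cond = torch.ones(2, max_prompt_len, dtype=torch.bool)
for idx, r in enumerate(cond_idxs):
    l = r[1] - r[0]
    attn_mask_cond[idx, l:] = 0.0  # mask padding past the real segment
print(attn_mask_cond.sum(dim=1))  # tensor([129,  60])
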
@@ -465,7 +468,7 @@ class GPT(nn.Module):
         # Compute speech conditioning input
         if cond_latents is None:
-            cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
+            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

         # Get logits
         sub = -5  # don't ask me why 😄
@@ -480,6 +483,7 @@ class GPT(nn.Module):
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
+            attn_mask_cond=attn_mask_cond,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
@@ -495,12 +499,19 @@ class GPT(nn.Module):
         for idx, l in enumerate(code_lengths):
             mel_targets[idx, l + 1 :] = -1

         # check if stop token is in every row of mel_targets
         assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
             0
         ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

+        # ignore the loss for the segment used for conditioning
+        if cond_idxs is not None:
+            for idx in range(cond_idxs.shape[0]):
+                cond_start = cond_idxs[idx, 0]
+                cond_end = cond_idxs[idx, 1]
+                mel_targets[idx, cond_start:cond_end] = -1
+
         # Compute losses
         loss_text = F.cross_entropy(
             text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
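
Targets set to -1 are skipped by F.cross_entropy through ignore_index=-1, which is how padding, and now the conditioning segment, are kept out of the mel loss. A self-contained demonstration:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 10, 6)   # (b, n_classes, t)
targets = torch.randint(0, 10, (1, 6))
targets[0, 2:4] = -1             # e.g. the conditioning segment

loss = F.cross_entropy(logits, targets, ignore_index=-1)
print(loss)  # positions 2-3 contribute nothing to the loss
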

View File

@@ -36,7 +36,6 @@ class GPTConfig(TortoiseConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
-    use_weighted_loss: bool = False  # TODO: move it to the base config
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -209,10 +208,10 @@ class GPTTrainer(BaseTTS):
         text_lengths: long tensor, (b,)
         audio_codes: long tensor, (b, m)
         wav_lengths: long tensor, (b,)
         cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
-        cond_lengths: long tensor, (b,)
+        cond_idxs: cond start and end indices, (b, 2)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
+        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
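
The trainer now forwards cond_idxs (per-sample prompt start and end) instead of cond_lens. For reference, plausible batch shapes at this call site (values invented for illustration):

import torch

# Illustrative batch shapes for GPTTrainer.forward
batch = {
    "text_inputs": torch.randint(0, 255, (2, 50)),    # (b, t)
    "text_lengths": torch.tensor([50, 42]),           # (b,)
    "audio_codes": torch.randint(0, 1024, (2, 200)),  # (b, m)
    "wav_lengths": torch.tensor([204800, 150000]),    # (b,)
    "cond_mels": torch.randn(2, 1, 80, 130),          # (b, num_samples, 80, t_m)
    # (b, 2) start/end in wav samples; converted to code frames inside GPT.forward
    "cond_idxs": torch.tensor([[0, 132300], [0, 88200]]),
}
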
@@ -228,7 +227,6 @@ class GPTTrainer(BaseTTS):
         batch["text_lengths"] = batch["text_lengths"]
         batch["wav_lengths"] = batch["wav_lengths"]
         batch["text_inputs"] = batch["padded_text"]
-        batch["cond_lens"] = batch["cond_lens"]
         batch["cond_idxs"] = batch["cond_idxs"]

         # compute conditioning mel specs
         # transform waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
@@ -261,7 +259,7 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        del batch["cond_lens"]

         return batch

     def train_step(self, batch, criterion):
@@ -272,11 +270,10 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
-        cond_lens = batch["cond_lens"]
-        # Todo: implement masking on the cond slice
+        cond_idxs = batch["cond_idxs"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
@@ -325,7 +322,6 @@ class GPTTrainer(BaseTTS):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            # Todo: remove the randomness of dataset when it is eval
             # init dataloader
             dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)

View File

@@ -253,7 +253,7 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
 # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
 DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]

 def freeze_layers(trainer):
     pass
@@ -262,7 +262,7 @@ def main():
     model_args = GPTArgs(
         max_conditioning_length=132300,  # 6 secs
         min_conditioning_length=66150,  # 3 secs
-        debug_loading_failures=True,
+        debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
         tokenizer_file="/raid/datasets/xtts_models/vocab.json",
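
The conditioning bounds in GPTArgs are expressed in samples at 22050 Hz; a quick sanity check of the conversions claimed in the comments:

sample_rate = 22050
print(6 * sample_rate)        # 132300 -> max_conditioning_length
print(3 * sample_rate)        # 66150  -> min_conditioning_length
print(255995 / sample_rate)   # ~11.6  -> max_wav_length in seconds
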