Add prompting masking

Edresson Casanova 2023-10-16 09:06:59 -03:00
parent 47d613df3a
commit bafab049c2
3 changed files with 46 additions and 39 deletions

View File

@@ -233,6 +233,7 @@ class GPT(nn.Module):
         prompt=None,
         get_attns=False,
         return_latent=False,
+        attn_mask_cond=None,
         attn_mask_text=None,
         attn_mask_mel=None,
     ):
@@ -248,8 +249,12 @@ class GPT(nn.Module):
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
+                if attn_mask_cond is not None:
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+                else:
+                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

         gpt_out = self.gpt(
             inputs_embeds=emb,
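Note: a minimal, self-contained sketch of how the combined attention mask ends up ordered as [conditioning prompt | text | mel] after this change; the sizes and the truncated sample are made up for illustration, only the concatenation pattern and the all-ones fallback come from the hunk above.

import torch

b, cond_len, text_len, mel_len = 2, 32, 10, 50  # illustrative sizes only

attn_mask_text = torch.ones(b, text_len, dtype=torch.bool)
attn_mask_mel = torch.ones(b, mel_len, dtype=torch.bool)
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)

# Caller-provided conditioning mask, e.g. with padded prompt frames already zeroed.
attn_mask_cond = torch.ones(b, cond_len, dtype=torch.bool)
attn_mask_cond[1, 20:] = False  # pretend sample 1 has a shorter conditioning prompt

if attn_mask_cond is not None:
    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
    # old behaviour: attend over the whole prompt
    attn_mask_cond = torch.ones(b, cond_len, dtype=torch.bool)
    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

print(attn_mask.shape)  # torch.Size([2, 92]) == (b, cond_len + text_len + mel_len)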
@@ -326,7 +331,7 @@ class GPT(nn.Module):
             prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
         return prompt

-    def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
+    def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
@@ -335,26 +340,7 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            if sample:
-                _len_secs = random.randint(2, 6)  # in secs
-                cond_seg_len = int((22050 / 1024) * _len_secs)  # in frames
-                if cond_input.shape[-1] >= cond_seg_len:
-                    new_conds = []
-                    for i in range(cond_input.shape[0]):
-                        cond_len = int(cond_lens[i] / 1024)
-                        if cond_len < cond_seg_len:
-                            start = 0
-                        else:
-                            start = random.randint(0, cond_len - cond_seg_len)
-                        cond_vec = cond_input[i, :, start : start + cond_seg_len]
-                        new_conds.append(cond_vec)
-                    conds = torch.stack(new_conds, dim=0)
-            else:
-                cond_seg_len = 5 if cond_seg_len is None else cond_seg_len  # secs
-                cond_frame_len = int((22050 / 1024) * cond_seg_len)
-                conds = cond_input[:, :, -cond_frame_len:]
-            conds = self.conditioning_encoder(conds)
+            conds = self.conditioning_encoder(cond_input)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
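Note: with the sampling branch gone, get_style_emb only normalizes the input shape and runs the conditioning encoder. A reduced sketch of what remains, using a stand-in 1x1 convolution because the real conditioning_encoder is not part of this diff:

import torch
import torch.nn as nn

conditioning_encoder = nn.Conv1d(80, 1024, kernel_size=1)  # stand-in for the real encoder

def get_style_emb(cond_input, return_latent=False):
    """cond_input: (b, 80, s) or (b, 1, 80, s) -> conds: (b, 1024, s)"""
    if not return_latent:
        if cond_input.ndim == 4:
            cond_input = cond_input.squeeze(1)       # drop the singleton channel dim
        conds = conditioning_encoder(cond_input)     # full mel goes straight to the encoder
    else:
        conds = cond_input.unsqueeze(1)              # latents were already computed upstream
    return conds

print(get_style_emb(torch.randn(2, 1, 80, 120)).shape)  # torch.Size([2, 1024, 120])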
@@ -366,10 +352,9 @@ class GPT(nn.Module):
         text_lengths,
         audio_codes,
         wav_lengths,
-        cond_lens=None,
         cond_mels=None,
+        cond_idxs=None,
         cond_latents=None,
-        loss_weights=None,
         return_attentions=False,
         return_latent=False,
     ):
@@ -377,11 +362,12 @@ class GPT(nn.Module):
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).

-        cond_mels: MEL float tensor, (b, 1, 80,s)
         text_inputs: long tensor, (b,t)
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
+        cond_mels: MEL float tensor, (b, 1, 80,s)
+        cond_idxs: cond start and end indexs, (b, 2)

         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ class GPT(nn.Module):
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

+        if cond_idxs is not None:
+            # recompute cond idxs for mel lengths
+            for idx, l in enumerate(code_lengths):
+                cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
+
         # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()
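Note: a worked example of the rescaling above, assuming code_stride_len is 1024 (the framing implied by the removed 22050 / 1024 arithmetic in get_style_emb); the sample indices are invented:

import torch

code_stride_len = 1024  # assumed here, matching the 22050 / 1024 framing used elsewhere

# cond_idxs arrives in wav samples, e.g. a conditioning slice from sample 22050 to 88200
cond_idxs = torch.tensor([[22050.0, 88200.0]])

# rescale to code-frame units so the indices line up with audio_codes / mel_targets
cond_idxs = cond_idxs / code_stride_len
print(cond_idxs)  # tensor([[21.5332, 86.1328]]) -> roughly code frames 21..86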
@@ -435,9 +426,16 @@ class GPT(nn.Module):
         )

         # Set attn_mask
+        attn_mask_cond = None
         attn_mask_text = None
         attn_mask_mel = None
         if not return_latent:
+            attn_mask_cond = torch.ones(
+                cond_mels.shape[0],
+                cond_mels.shape[-1],
+                dtype=torch.bool,
+                device=text_inputs.device,
+            )
             attn_mask_text = torch.ones(
                 text_inputs.shape[0],
                 text_inputs.shape[1],
@@ -451,6 +449,11 @@ class GPT(nn.Module):
                 device=audio_codes.device,
             )

+            if cond_idxs is not None:
+                for idx, r in enumerate(cond_idxs.squeeze()):
+                    l = r[1] - r[0]
+                    attn_mask_cond[idx, l:] = 0.0
+
             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0
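Note: a small sketch of the truncation above: positions past the length of the conditioning segment (end minus start, in code frames) are zeroed so padded prompt frames are not attended to. Batch size and lengths are made up:

import torch

b, cond_frames = 2, 100  # illustrative batch size / padded prompt length
attn_mask_cond = torch.ones(b, cond_frames, dtype=torch.bool)

# per-sample [start, end) of the conditioning segment, already in code frames
cond_idxs = torch.tensor([[0, 64], [10, 40]])

for idx, r in enumerate(cond_idxs):
    l = r[1] - r[0]                # segment length
    attn_mask_cond[idx, l:] = 0.0  # mask everything past the real segment

print(attn_mask_cond.sum(dim=1))   # tensor([64, 30])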
@@ -465,7 +468,7 @@ class GPT(nn.Module):
         # Compute speech conditioning input
         if cond_latents is None:
-            cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
+            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

         # Get logits
         sub = -5  # don't ask me why 😄
@@ -480,6 +483,7 @@ class GPT(nn.Module):
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
+            attn_mask_cond=attn_mask_cond,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
@@ -495,12 +499,19 @@ class GPT(nn.Module):
         for idx, l in enumerate(code_lengths):
             mel_targets[idx, l + 1 :] = -1

         # check if stoptoken is in every row of mel_targets
         assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
             0
         ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

+        # ignore the loss for the segment used for conditioning
+        # coin flip for the segment to be ignored
+        if cond_idxs is not None:
+            cond_start = cond_idxs[idx, 0]
+            cond_end = cond_idxs[idx, 1]
+            mel_targets[idx, cond_start:cond_end] = -1
+
         # Compute losses
         loss_text = F.cross_entropy(
             text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
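Note: the target masking above leans on F.cross_entropy's ignore_index, so positions set to -1 (padding and now also the conditioning segment) contribute nothing to the mel loss. A minimal sketch with invented logits and vocabulary size:

import torch
import torch.nn.functional as F

num_classes, seq_len = 8, 10
logits = torch.randn(1, num_classes, seq_len)              # (b, classes, t)
mel_targets = torch.randint(0, num_classes, (1, seq_len))  # (b, t)

# pretend frames 2..4 were the conditioning segment: exclude them from the loss
cond_start, cond_end = 2, 5
mel_targets[0, cond_start:cond_end] = -1

loss = F.cross_entropy(logits, mel_targets, ignore_index=-1)
print(loss)  # averaged only over the positions that are not -1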

View File

@@ -36,7 +36,6 @@ class GPTConfig(TortoiseConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
-    use_weighted_loss: bool = False  # TODO: move it to the base config
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -209,10 +208,10 @@ class GPTTrainer(BaseTTS):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
-        cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
+        cond_idxs: cond start and end indexs, (b, 2)
         cond_lengths: long tensor, (b,)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
+        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
@@ -228,7 +227,6 @@ class GPTTrainer(BaseTTS):
         batch["text_lengths"] = batch["text_lengths"]
         batch["wav_lengths"] = batch["wav_lengths"]
         batch["text_inputs"] = batch["padded_text"]
-        batch["cond_lens"] = batch["cond_lens"]
         batch["cond_idxs"] = batch["cond_idxs"]
         # compute conditioning mel specs
         # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor
@@ -261,7 +259,7 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        del batch["cond_lens"]
         return batch

     def train_step(self, batch, criterion):
@@ -272,11 +270,10 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]
-        cond_lens=batch["cond_lens"]
         # Todo: implement masking on the cond slice
         cond_idxs = batch["cond_idxs"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
@@ -325,7 +322,6 @@ class GPTTrainer(BaseTTS):
         if is_eval and not config.run_eval:
             loader = None
         else:
-            # Todo: remove the randomness of dataset when it is eval
             # init dataloader
             dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)

View File

@@ -253,7 +253,7 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
 # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
 DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]

 def freeze_layers(trainer):
     pass
@@ -262,7 +262,7 @@ def main():
     model_args = GPTArgs(
         max_conditioning_length=132300,  # 6 secs
         min_conditioning_length=66150,  # 3 secs
-        debug_loading_failures=True,
+        debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
         tokenizer_file="/raid/datasets/xtts_models/vocab.json",