mirror of https://github.com/coqui-ai/TTS.git
Add prompting masking
This commit is contained in:
parent 47d613df3a
commit bafab049c2
@@ -233,6 +233,7 @@ class GPT(nn.Module):
         prompt=None,
         get_attns=False,
         return_latent=False,
+        attn_mask_cond=None,
         attn_mask_text=None,
         attn_mask_mel=None,
     ):
@@ -248,8 +249,12 @@ class GPT(nn.Module):
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
-                attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
+                if attn_mask_cond is not None:
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
+                else:
+                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
+                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

         gpt_out = self.gpt(
             inputs_embeds=emb,
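The new branch lets callers pass a precomputed `attn_mask_cond` and only falls back to the old all-ones prompt mask when none is given. A minimal standalone sketch of the resulting concatenation order, [cond | text | mel], using made-up batch and length values:

import torch

b, cond_len, text_len, mel_len = 2, 4, 5, 7                 # illustrative sizes only
attn_mask_cond = torch.ones(b, cond_len, dtype=torch.bool)
attn_mask_cond[0, 2:] = False                               # pretend sample 0 has only 2 valid cond frames
attn_mask_text = torch.ones(b, text_len, dtype=torch.bool)
attn_mask_mel = torch.ones(b, mel_len, dtype=torch.bool)

attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)  # [text | mel]
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)      # [cond | text | mel]
print(attn_mask.shape)  # torch.Size([2, 16]): one keep/drop flag per input position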
@@ -326,7 +331,7 @@ class GPT(nn.Module):
         prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
         return prompt

-    def get_style_emb(self, cond_input, cond_lens=None, cond_seg_len=None, return_latent=False, sample=True):
+    def get_style_emb(self, cond_input, return_latent=False):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
@@ -335,26 +340,7 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            if sample:
-                _len_secs = random.randint(2, 6)  # in secs
-                cond_seg_len = int((22050 / 1024) * _len_secs)  # in frames
-                if cond_input.shape[-1] >= cond_seg_len:
-                    new_conds = []
-                    for i in range(cond_input.shape[0]):
-                        cond_len = int(cond_lens[i] / 1024)
-                        if cond_len < cond_seg_len:
-                            start = 0
-                        else:
-                            start = random.randint(0, cond_len - cond_seg_len)
-                        cond_vec = cond_input[i, :, start : start + cond_seg_len]
-                        new_conds.append(cond_vec)
-                    conds = torch.stack(new_conds, dim=0)
-            else:
-                cond_seg_len = 5 if cond_seg_len is None else cond_seg_len  # secs
-                cond_frame_len = int((22050 / 1024) * cond_seg_len)
-                conds = cond_input[:, :, -cond_frame_len:]
-
-            conds = self.conditioning_encoder(conds)
+            conds = self.conditioning_encoder(cond_input)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
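The deleted block cropped a random 2-6 second window from the conditioning mel inside the model; after this change `get_style_emb` encodes whatever it receives, and the slicing is expected to happen upstream (which is what `cond_idxs` tracks elsewhere in this commit). For reference, the seconds-to-frames arithmetic the removed code relied on, assuming 22050 Hz audio and a 1024-sample hop:

# roughly 21.5 mel frames per second at 22050 Hz with a 1024-sample hop
frames_per_sec = 22050 / 1024
print(int(frames_per_sec * 2))  # 43 frames, the old 2 s minimum window
print(int(frames_per_sec * 6))  # 129 frames, the old 6 s maximum window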
@@ -366,10 +352,9 @@ class GPT(nn.Module):
         text_lengths,
         audio_codes,
         wav_lengths,
-        cond_lens=None,
         cond_mels=None,
+        cond_idxs=None,
         cond_latents=None,
-        loss_weights=None,
         return_attentions=False,
         return_latent=False,
     ):
@@ -377,11 +362,12 @@ class GPT(nn.Module):
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).

-        cond_mels: MEL float tensor, (b, 1, 80, s)
         text_inputs: long tensor, (b,t)
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
+        cond_mels: MEL float tensor, (b, 1, 80, s)
+        cond_idxs: cond start and end indices, (b, 2)

         If return_attentions is specified, only logits are returned.
         If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
@@ -393,6 +379,11 @@ class GPT(nn.Module):
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

+        if cond_idxs is not None:
+            # recompute cond idxs for mel lengths
+            for idx, l in enumerate(code_lengths):
+                cond_idxs[idx] = cond_idxs[idx] / self.code_stride_len
+
         # If len(codes) + 3 is larger than the maximum allowed length, we truncate the codes.
         max_mel_len = code_lengths.max()
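`cond_idxs` arrives in waveform samples, and this loop rescales it to the code-frame axis, where one code covers `code_stride_len` samples. A worked example with assumed values (stride 1024, a slice from 3 s to 6 s of 22050 Hz audio):

import torch

code_stride_len = 1024                        # assumed; matches the division above
cond_idxs = torch.tensor([[66150, 132300]])   # [start, end] of the cond slice in samples
cond_idxs = cond_idxs // code_stride_len      # integer form of the rescaling
print(cond_idxs)                              # tensor([[ 64, 129]])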
@@ -435,9 +426,16 @@ class GPT(nn.Module):
         )

         # Set attn_mask
+        attn_mask_cond = None
         attn_mask_text = None
         attn_mask_mel = None
         if not return_latent:
+            attn_mask_cond = torch.ones(
+                cond_mels.shape[0],
+                cond_mels.shape[-1],
+                dtype=torch.bool,
+                device=text_inputs.device,
+            )
             attn_mask_text = torch.ones(
                 text_inputs.shape[0],
                 text_inputs.shape[1],
@@ -451,6 +449,11 @@ class GPT(nn.Module):
                 device=audio_codes.device,
             )

+            if cond_idxs is not None:
+                for idx, r in enumerate(cond_idxs.squeeze()):
+                    l = r[1] - r[0]
+                    attn_mask_cond[idx, l:] = 0.0
+
             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0
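The mask truncation assumes the valid conditioning frames sit at the front of each row of `cond_mels`, so everything past `end - start` is padding and gets zeroed out of attention. A sketch with made-up indices:

import torch

cond_idxs = torch.tensor([[64, 129], [0, 86]])     # per-sample [start, end] in code frames
attn_mask_cond = torch.ones(2, 150, dtype=torch.bool)
for idx, r in enumerate(cond_idxs):
    valid = r[1] - r[0]                 # number of real conditioning frames
    attn_mask_cond[idx, valid:] = False
print(attn_mask_cond.sum(dim=1))        # tensor([65, 86])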
@@ -465,7 +468,7 @@ class GPT(nn.Module):

         # Compute speech conditioning input
         if cond_latents is None:
-            cond_latents = self.get_style_emb(cond_mels, cond_lens).transpose(1, 2)
+            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

         # Get logits
         sub = -5  # don't ask me why 😄
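Shape walkthrough for the conditioning path, under the (b, 80, s) layout from the `get_style_emb` docstring: the encoder maps 80 mel channels to the model width, and the transpose puts time first so each frame becomes one prompt embedding. The encoder below is a stand-in used only to show shapes:

import torch

encoder = torch.nn.Conv1d(80, 1024, kernel_size=1)  # stand-in for conditioning_encoder

cond_mels = torch.randn(2, 80, 129)      # (b, 80, s), made-up sizes
conds = encoder(cond_mels)               # (b, 1024, s), as in the docstring
cond_latents = conds.transpose(1, 2)     # (b, s, 1024): one latent per mel frame
print(cond_latents.shape)                # torch.Size([2, 129, 1024])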
@@ -480,6 +483,7 @@ class GPT(nn.Module):
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
+            attn_mask_cond=attn_mask_cond,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
@@ -495,12 +499,19 @@ class GPT(nn.Module):
         for idx, l in enumerate(code_lengths):
             mel_targets[idx, l + 1 :] = -1

         # check if stoptoken is in every row of mel_targets
         assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[
             0
         ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row."

+        # ignore the loss for the segment used for conditioning
+        # coin flip for the segment to be ignored
+        if cond_idxs is not None:
+            cond_start = cond_idxs[idx, 0]
+            cond_end = cond_idxs[idx, 1]
+            mel_targets[idx, cond_start:cond_end] = -1
+
         # Compute losses
         loss_text = F.cross_entropy(
             text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
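Targets set to -1 are skipped by `F.cross_entropy(..., ignore_index=-1)`, so the conditioning slice contributes nothing to the mel loss. Note that, as written, the new block runs after the `for` loop over `code_lengths` and reuses its leftover `idx`, so it masks the slice only for the last batch item; masking every row would require indenting it into the loop. A minimal sketch of the mechanism, with assumed shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 10, 6)            # (batch, n_codes, time), made-up sizes
targets = torch.randint(0, 10, (2, 6))    # (batch, time)
targets[:, 2:4] = -1                      # pretend frames 2..3 were the cond slice
loss = F.cross_entropy(logits, targets, ignore_index=-1)  # frames 2..3 ignored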
@@ -36,7 +36,6 @@ class GPTConfig(TortoiseConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
-    use_weighted_loss: bool = False  # TODO: move it to the base config
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
@@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
     def device(self):
         return next(self.parameters()).device

-    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
+    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
         """
         Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
         (actuated by `text_first`).
@@ -209,10 +208,10 @@ class GPTTrainer(BaseTTS):
         text_lengths: long tensor, (b,)
         mel_inputs: long tensor, (b,m)
         wav_lengths: long tensor, (b,)
         cond_mels: MEL float tensor, (b, num_samples, 80, t_m)
-        cond_lengths: long tensor, (b,)
+        cond_idxs: cond start and end indices, (b, 2)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
+        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
@@ -228,7 +227,6 @@ class GPTTrainer(BaseTTS):
         batch["text_lengths"] = batch["text_lengths"]
         batch["wav_lengths"] = batch["wav_lengths"]
         batch["text_inputs"] = batch["padded_text"]
-        batch["cond_lens"] = batch["cond_lens"]
         batch["cond_idxs"] = batch["cond_idxs"]
         # compute conditioning mel specs
         # transform waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating the tensor
@@ -261,7 +259,7 @@ class GPTTrainer(BaseTTS):
         del batch["padded_text"]
         del batch["wav"]
         del batch["conditioning"]
-        del batch["cond_lens"]
         return batch

     def train_step(self, batch, criterion):
@@ -272,11 +270,10 @@ class GPTTrainer(BaseTTS):
         audio_codes = batch["audio_codes"]
         wav_lengths = batch["wav_lengths"]

-        cond_lens = batch["cond_lens"]
-        # Todo: implement masking on the cond slice
+        cond_idxs = batch["cond_idxs"]

-        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
+        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
         loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
         loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
         loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
@@ -325,7 +322,6 @@ class GPTTrainer(BaseTTS):
         if is_eval and not config.run_eval:
             loader = None
         else:
             # Todo: remove the randomness of dataset when it is eval
             # init dataloader
             dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
@@ -253,7 +253,7 @@ config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
 # DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]

 DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]


 def freeze_layers(trainer):
     pass
@@ -262,7 +262,7 @@ def main():
     model_args = GPTArgs(
         max_conditioning_length=132300,  # 6 secs
         min_conditioning_length=66150,  # 3 secs
-        debug_loading_failures=True,
+        debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
         tokenizer_file="/raid/datasets/xtts_models/vocab.json",
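The conditioning lengths are given in waveform samples at the recipe's 22050 Hz sample rate; a quick check of the inline comments:

sample_rate = 22050
print(132300 / sample_rate)  # 6.0    -> max_conditioning_length is 6 s
print(66150 / sample_rate)   # 3.0    -> min_conditioning_length is 3 s
print(255995 / sample_rate)  # ~11.61 -> matches the "~11.6 seconds" note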