mirror of https://github.com/coqui-ai/TTS.git
Update GPT Trainer for perceiver support
This commit is contained in:
parent
dff3902ca8
commit
8479a3702c
|
@ -494,10 +494,9 @@ class GPT(nn.Module):
|
||||||
|
|
||||||
# Compute speech conditioning input
|
# Compute speech conditioning input
|
||||||
if cond_latents is None:
|
if cond_latents is None:
|
||||||
if cond_lens is not None:
|
if cond_lens is not None and cond_idxs is None:
|
||||||
min_cond_len = torch.min(cond_lens)
|
min_cond_len = torch.min(cond_lens)
|
||||||
cond_mels = cond_mels[:, :, :, :min_cond_len]
|
cond_mels = cond_mels[:, :, :, :min_cond_len]
|
||||||
|
|
||||||
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
|
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
|
||||||
|
|
||||||
# Get logits
|
# Get logits
|
||||||
|
|
|
@ -88,7 +88,7 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.max_wav_len = model_args.max_wav_length
|
self.max_wav_len = model_args.max_wav_length
|
||||||
self.max_text_len = model_args.max_text_length
|
self.max_text_len = model_args.max_text_length
|
||||||
self.use_masking_gt_as_prompt = model_args.use_masking_gt_as_prompt
|
self.use_masking_gt_as_prompt = model_args.gpt_use_masking_gt_as_prompt
|
||||||
assert self.max_wav_len is not None and self.max_text_len is not None
|
assert self.max_wav_len is not None and self.max_text_len is not None
|
||||||
|
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
|
@ -143,16 +143,18 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
|
|
||||||
if self.use_masking_gt_as_prompt:
|
if self.use_masking_gt_as_prompt:
|
||||||
# get a slice from GT to condition the model
|
# get a slice from GT to condition the model
|
||||||
cond, cond_len, cond_idxs = get_prompt_slice(
|
cond, _, cond_idxs = get_prompt_slice(
|
||||||
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||||
)
|
)
|
||||||
|
# if use masking do not use cond_len
|
||||||
|
cond_len = torch.nan
|
||||||
else:
|
else:
|
||||||
ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
|
ref_sample = sample["reference_path"] if "reference_path" in sample and sample["reference_path"] is not None else audiopath
|
||||||
cond, cond_len, cond_idxs = get_prompt_slice(
|
cond, cond_len, _ = get_prompt_slice(
|
||||||
ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
|
||||||
)
|
)
|
||||||
|
# if do not use masking use cond_len
|
||||||
cond_idxs = torch.nan
|
cond_idxs = torch.nan
|
||||||
cond_len = torch.nan
|
|
||||||
|
|
||||||
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
return tseq, audiopath, wav, cond, cond_len, cond_idxs
|
||||||
|
|
||||||
|
@ -230,9 +232,12 @@ class XTTSDataset(torch.utils.data.Dataset):
|
||||||
batch["conditioning"] = torch.stack(batch["conditioning"])
|
batch["conditioning"] = torch.stack(batch["conditioning"])
|
||||||
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
batch["cond_lens"] = torch.stack(batch["cond_lens"])
|
||||||
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
|
||||||
|
|
||||||
if torch.any(batch["cond_idxs"].isnan()):
|
if torch.any(batch["cond_idxs"].isnan()):
|
||||||
batch["cond_lens"] = None
|
batch["cond_lens"] = None
|
||||||
batch["cond_idxs"] = None
|
|
||||||
|
if torch.any(batch["cond_lens"].isnan()):
|
||||||
|
batch["cond_lens"] = None
|
||||||
|
|
||||||
max_text_len = batch["text_lengths"].max()
|
max_text_len = batch["text_lengths"].max()
|
||||||
max_wav_len = batch["wav_lengths"].max()
|
max_wav_len = batch["wav_lengths"].max()
|
||||||
|
|
|
@ -52,7 +52,6 @@ class GPTArgs(XttsArgs):
|
||||||
xtts_checkpoint: str = ""
|
xtts_checkpoint: str = ""
|
||||||
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model
|
||||||
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
vocoder: str = "" # overide vocoder key on the config to avoid json write issues
|
||||||
use_masking_gt_as_prompt: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
def callback_clearml_load_save(operation_type, model_info):
|
def callback_clearml_load_save(operation_type, model_info):
|
||||||
|
@ -200,7 +199,7 @@ class GPTTrainer(BaseTTS):
|
||||||
def device(self):
|
def device(self):
|
||||||
return next(self.parameters()).device
|
return next(self.parameters()).device
|
||||||
|
|
||||||
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs):
|
def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens):
|
||||||
"""
|
"""
|
||||||
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
|
||||||
(actuated by `text_first`).
|
(actuated by `text_first`).
|
||||||
|
@ -211,9 +210,10 @@ class GPTTrainer(BaseTTS):
|
||||||
wav_lengths: long tensor, (b,)
|
wav_lengths: long tensor, (b,)
|
||||||
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
|
cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
|
||||||
cond_idxs: cond start and end indexs, (b, 2)
|
cond_idxs: cond start and end indexs, (b, 2)
|
||||||
|
cond_lens: long tensor, (b,)
|
||||||
"""
|
"""
|
||||||
losses = self.xtts.gpt(
|
losses = self.xtts.gpt(
|
||||||
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs
|
text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs, cond_lens=cond_lens,
|
||||||
)
|
)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
@ -283,7 +283,6 @@ class GPTTrainer(BaseTTS):
|
||||||
del batch["padded_text"]
|
del batch["padded_text"]
|
||||||
del batch["wav"]
|
del batch["wav"]
|
||||||
del batch["conditioning"]
|
del batch["conditioning"]
|
||||||
del batch["cond_lens"]
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def train_step(self, batch, criterion):
|
def train_step(self, batch, criterion):
|
||||||
|
@ -294,8 +293,9 @@ class GPTTrainer(BaseTTS):
|
||||||
audio_codes = batch["audio_codes"]
|
audio_codes = batch["audio_codes"]
|
||||||
wav_lengths = batch["wav_lengths"]
|
wav_lengths = batch["wav_lengths"]
|
||||||
cond_idxs = batch["cond_idxs"]
|
cond_idxs = batch["cond_idxs"]
|
||||||
|
cond_lens = batch["cond_lens"]
|
||||||
|
|
||||||
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs)
|
loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens)
|
||||||
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
|
loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
|
||||||
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
|
loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
|
||||||
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
|
loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
|
||||||
|
|
|
@ -241,6 +241,7 @@ class XttsArgs(Coqpit):
|
||||||
gpt_num_audio_tokens: int = 8194
|
gpt_num_audio_tokens: int = 8194
|
||||||
gpt_start_audio_token: int = 8192
|
gpt_start_audio_token: int = 8192
|
||||||
gpt_stop_audio_token: int = 8193
|
gpt_stop_audio_token: int = 8193
|
||||||
|
gpt_use_masking_gt_as_prompt: bool = True
|
||||||
gpt_use_perceiver_resampler: bool = False
|
gpt_use_perceiver_resampler: bool = False
|
||||||
|
|
||||||
# Diffusion Decoder params
|
# Diffusion Decoder params
|
||||||
|
|
Loading…
Reference in New Issue