refactor(xtts): use tortoise conditioning encoder

This commit is contained in:
Enno Hermann 2024-11-22 15:38:35 +01:00
parent 69a599d403
commit 6ecf47312c
4 changed files with 7 additions and 65 deletions

View File

@ -176,7 +176,6 @@ class ConditioningEncoder(nn.Module):
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
mean=False,
):
super().__init__()
attn = []
@ -185,15 +184,14 @@ class ConditioningEncoder(nn.Module):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
self.mean = mean
def forward(self, x):
"""
x: (b, 80, s)
"""
h = self.init(x)
h = self.attn(h)
if self.mean:
return h.mean(dim=2)
else:
return h[:, :, 0]
return h
class LearnedPositionEmbeddings(nn.Module):
@ -473,7 +471,7 @@ class UnifiedVoice(nn.Module):
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds

View File

@ -8,12 +8,12 @@ import torch.nn.functional as F
from transformers import GPT2Config
from TTS.tts.layers.tortoise.autoregressive import (
ConditioningEncoder,
LearnedPositionEmbeddings,
_prepare_attention_mask_for_generation,
build_hf_gpt_transformer,
)
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@ -235,19 +235,6 @@ class GPT(nn.Module):
else:
return first_logits
def get_conditioning(self, speech_conditioning_input):
speech_conditioning_input = (
speech_conditioning_input.unsqueeze(1)
if len(speech_conditioning_input.shape) == 3
else speech_conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
conds = torch.stack(conds, dim=1)
conds = conds.mean(dim=1)
return conds
def get_prompts(self, prompt_codes):
"""
Create a prompt from the mel codes. This is used to condition the model on the mel codes.
@ -286,6 +273,7 @@ class GPT(nn.Module):
"""
cond_input: (b, 80, s) or (b, 1, 80, s)
conds: (b, 1024, s)
output: (b, 1024, 32)
"""
conds = None
if not return_latent:

View File

@ -93,28 +93,3 @@ class AttentionBlock(nn.Module):
h = self.proj_out(h)
xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial)
class ConditioningEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
"""
x: (b, 80, s)
"""
h = self.init(x)
h = self.attn(h)
return h

View File

@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
return audio
def pad_or_truncate(t, length):
"""
Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
Args:
t (torch.Tensor): The input tensor to be padded or truncated.
length (int): The desired length of the tensor.
Returns:
torch.Tensor: The padded or truncated tensor.
"""
tp = t[..., :length]
if t.shape[-1] == length:
tp = t
elif t.shape[-1] < length:
tp = F.pad(t, (0, length - t.shape[-1]))
return tp
@dataclass
class XttsAudioConfig(Coqpit):
"""