refactor(xtts): use tortoise conditioning encoder

Enno Hermann 2024-11-22 15:38:35 +01:00
parent 69a599d403
commit 6ecf47312c
4 changed files with 7 additions and 65 deletions

View File

@@ -176,7 +176,6 @@ class ConditioningEncoder(nn.Module):
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
-        mean=False,
     ):
         super().__init__()
         attn = []
@@ -185,15 +184,14 @@ class ConditioningEncoder(nn.Module):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
-        self.mean = mean

     def forward(self, x):
+        """
+        x: (b, 80, s)
+        """
         h = self.init(x)
         h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
+        return h


 class LearnedPositionEmbeddings(nn.Module):
@@ -473,7 +471,7 @@ class UnifiedVoice(nn.Module):
         )
         conds = []
         for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
         conds = torch.stack(conds, dim=1)
         conds = conds.mean(dim=1)
         return conds
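The net effect of the hunks above: the Tortoise ConditioningEncoder no longer reduces its output internally. forward() returns the full (b, embedding_dim, s) sequence and each caller picks the reduction it needs (UnifiedVoice.get_conditioning takes the first frame; the removed mean=True path is simply a mean over dim=2 at the call site). A minimal sketch of that contract, not the library code, with AttentionBlock stubbed out by nn.Identity purely for illustration:

import torch
from torch import nn


class ConditioningEncoderSketch(nn.Module):
    def __init__(self, spec_dim=80, embedding_dim=1024, attn_blocks=6):
        super().__init__()
        # 1x1 conv lifts the 80-bin mel into the embedding space, as in the diff above
        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
        # stand-in for the stack of AttentionBlock layers (shape-preserving)
        self.attn = nn.Sequential(*[nn.Identity() for _ in range(attn_blocks)])
        self.dim = embedding_dim

    def forward(self, x):
        # x: (b, 80, s) -> (b, embedding_dim, s); no pooling or frame selection here any more
        return self.attn(self.init(x))


mel = torch.randn(2, 80, 120)           # (b, 80, s) conditioning mel
h = ConditioningEncoderSketch()(mel)    # (b, 1024, 120)
first_frame = h[:, :, 0]                # what UnifiedVoice.get_conditioning now selects
mean_pooled = h.mean(dim=2)             # the removed mean=True behaviour, done by the caller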

View File

@ -8,12 +8,12 @@ import torch.nn.functional as F
from transformers import GPT2Config from transformers import GPT2Config
from TTS.tts.layers.tortoise.autoregressive import ( from TTS.tts.layers.tortoise.autoregressive import (
ConditioningEncoder,
LearnedPositionEmbeddings, LearnedPositionEmbeddings,
_prepare_attention_mask_for_generation, _prepare_attention_mask_for_generation,
build_hf_gpt_transformer, build_hf_gpt_transformer,
) )
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
@@ -235,19 +235,6 @@ class GPT(nn.Module):
         else:
             return first_logits

-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
     def get_prompts(self, prompt_codes):
         """
         Create a prompt from the mel codes. This is used to condition the model on the mel codes.
@@ -286,6 +273,7 @@ class GPT(nn.Module):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
+        output: (b, 1024, 32)
         """
         conds = None
         if not return_latent:
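The added docstring line pins the conditioning output to 32 latent vectors. As a shape-contract illustration only, and with a hypothetical stand-in rather than the library's PerceiverResampler, cross-attending a fixed set of 32 learned latents over the (b, 1024, s) conditioning sequence produces exactly that shape:

import torch
from torch import nn


class FixedLatentResamplerSketch(nn.Module):
    # hypothetical stand-in: 32 learned latents cross-attend to the conditioning sequence
    def __init__(self, dim=1024, num_latents=32, num_heads=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True)

    def forward(self, conds):                       # conds: (b, 1024, s)
        x = conds.transpose(1, 2)                   # (b, s, 1024)
        q = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)  # (b, 32, 1024)
        out, _ = self.attn(q, x, x)                 # latents attend over the sequence
        return out.transpose(1, 2)                  # (b, 1024, 32)


conds = torch.randn(2, 1024, 250)
print(FixedLatentResamplerSketch()(conds).shape)    # torch.Size([2, 1024, 32])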

View File

@ -93,28 +93,3 @@ class AttentionBlock(nn.Module):
h = self.proj_out(h) h = self.proj_out(h)
xp = self.x_proj(x) xp = self.x_proj(x)
return (xp + h).reshape(b, xp.shape[1], *spatial) return (xp + h).reshape(b, xp.shape[1], *spatial)
class ConditioningEncoder(nn.Module):
def __init__(
self,
spec_dim,
embedding_dim,
attn_blocks=6,
num_attn_heads=4,
):
super().__init__()
attn = []
self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
def forward(self, x):
"""
x: (b, 80, s)
"""
h = self.init(x)
h = self.attn(h)
return h

View File

@@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
     return audio


-def pad_or_truncate(t, length):
-    """
-    Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
-
-    Args:
-        t (torch.Tensor): The input tensor to be padded or truncated.
-        length (int): The desired length of the tensor.
-
-    Returns:
-        torch.Tensor: The padded or truncated tensor.
-    """
-    tp = t[..., :length]
-    if t.shape[-1] == length:
-        tp = t
-    elif t.shape[-1] < length:
-        tp = F.pad(t, (0, length - t.shape[-1]))
-    return tp
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """