mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): use tortoise conditioning encoder
commit 6ecf47312c
parent 69a599d403
TTS/tts/layers/tortoise/autoregressive.py

@@ -176,7 +176,6 @@ class ConditioningEncoder(nn.Module):
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
-        mean=False,
     ):
         super().__init__()
         attn = []
@@ -185,15 +184,14 @@ class ConditioningEncoder(nn.Module):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
-        self.mean = mean

     def forward(self, x):
+        """
+        x: (b, 80, s)
+        """
         h = self.init(x)
         h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
+        return h


 class LearnedPositionEmbeddings(nn.Module):
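With the mean flag gone, the tortoise ConditioningEncoder always returns its full attention output of shape (b, embedding_dim, s) and leaves any reduction over the time axis to the caller. A minimal shape-level sketch of the new contract (the 1024 width is taken from the conds shape documented in gpt.py below; illustrative only):

import torch

# Stand-in for the refactored encoder output: (b, embedding_dim, s).
b, dim, s = 2, 1024, 50
h = torch.randn(b, dim, s)

first = h[:, :, 0]    # what UnifiedVoice.get_conditioning now selects per clip
mean = h.mean(dim=2)  # what the removed mean=True branch used to compute

assert first.shape == (b, dim) and mean.shape == (b, dim)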
@@ -473,7 +471,7 @@ class UnifiedVoice(nn.Module):
         )
         conds = []
         for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
         conds = torch.stack(conds, dim=1)
         conds = conds.mean(dim=1)
         return conds
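UnifiedVoice.get_conditioning therefore applies the [:, :, 0] selection itself before stacking the per-clip embeddings and averaging them into a single conditioning vector. A sketch of that aggregation, with a dummy encoder standing in for self.conditioning_encoder:

import torch

def encoder(mel):  # dummy for self.conditioning_encoder: (b, 80, s) -> (b, 1024, s)
    return torch.randn(mel.shape[0], 1024, mel.shape[-1])

clips = torch.randn(2, 3, 80, 50)  # (b, n_clips, 80, s) reference mels
conds = [encoder(clips[:, j])[:, :, 0] for j in range(clips.shape[1])]
conds = torch.stack(conds, dim=1)  # (b, n_clips, 1024)
conds = conds.mean(dim=1)          # (b, 1024): one embedding per batch item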
TTS/tts/layers/xtts/gpt.py

@@ -8,12 +8,12 @@ import torch.nn.functional as F
 from transformers import GPT2Config

 from TTS.tts.layers.tortoise.autoregressive import (
+    ConditioningEncoder,
     LearnedPositionEmbeddings,
     _prepare_attention_mask_for_generation,
     build_hf_gpt_transformer,
 )
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
-from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
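With the import swap, the XTTS GPT reuses tortoise's ConditioningEncoder instead of the near-identical copy deleted from latent_encoder.py below; the constructor signature is unchanged. A usage sketch (the 80/1024 sizes mirror the shapes named in the docstrings; illustrative, not taken from this diff):

import torch
from TTS.tts.layers.tortoise.autoregressive import ConditioningEncoder

enc = ConditioningEncoder(80, 1024, num_attn_heads=4)
latents = enc(torch.randn(1, 80, 120))  # (1, 1024, 120): full sequence, no reduction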
@@ -235,19 +235,6 @@ class GPT(nn.Module):
         else:
             return first_logits

-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
     def get_prompts(self, prompt_codes):
         """
         Create a prompt from the mel codes. This is used to condition the model on the mel codes.
@@ -286,6 +273,7 @@ class GPT(nn.Module):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
+        output: (b, 1024, 32)
         """
         conds = None
         if not return_latent:
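The added docstring line records the new output contract: the variable-length conditioning sequence (b, 1024, s) ends up as a fixed 32 latents, which is what a perceiver-style resampler (see the PerceiverResampler import above) produces. A toy stand-in that reproduces the same shape behavior by cross-attending 32 learned queries over the sequence; hypothetical, not the real module:

import torch
import torch.nn as nn

b, dim, s = 2, 1024, 300
conds = torch.randn(b, dim, s)                  # (b, 1024, s) encoder output

queries = nn.Parameter(torch.randn(32, dim))    # 32 learned latents
attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
kv = conds.transpose(1, 2)                      # (b, s, dim)
q = queries.unsqueeze(0).expand(b, -1, -1)      # (b, 32, dim)
out, _ = attn(q, kv, kv)                        # (b, 32, dim)
output = out.transpose(1, 2)                    # (b, 1024, 32), independent of s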
TTS/tts/layers/xtts/latent_encoder.py

@@ -93,28 +93,3 @@ class AttentionBlock(nn.Module):
         h = self.proj_out(h)
         xp = self.x_proj(x)
         return (xp + h).reshape(b, xp.shape[1], *spatial)
-
-
-class ConditioningEncoder(nn.Module):
-    def __init__(
-        self,
-        spec_dim,
-        embedding_dim,
-        attn_blocks=6,
-        num_attn_heads=4,
-    ):
-        super().__init__()
-        attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
-        for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
-        self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
-
-    def forward(self, x):
-        """
-        x: (b, 80, s)
-        """
-        h = self.init(x)
-        h = self.attn(h)
-        return h
TTS/tts/models/xtts.py

@@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
     return audio


-def pad_or_truncate(t, length):
-    """
-    Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
-
-    Args:
-        t (torch.Tensor): The input tensor to be padded or truncated.
-        length (int): The desired length of the tensor.
-
-    Returns:
-        torch.Tensor: The padded or truncated tensor.
-    """
-    tp = t[..., :length]
-    if t.shape[-1] == length:
-        tp = t
-    elif t.shape[-1] < length:
-        tp = F.pad(t, (0, length - t.shape[-1]))
-    return tp
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """
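For reference while reviewing the removal: the deleted helper right-pads with zeros or clips to the target length, and can be exercised standalone:

import torch
import torch.nn.functional as F

def pad_or_truncate(t, length):
    """Copy of the helper removed above."""
    tp = t[..., :length]
    if t.shape[-1] == length:
        tp = t
    elif t.shape[-1] < length:
        tp = F.pad(t, (0, length - t.shape[-1]))
    return tp

x = torch.ones(1, 5)
print(pad_or_truncate(x, 8).shape)  # torch.Size([1, 8]), zero-padded on the right
print(pad_or_truncate(x, 3).shape)  # torch.Size([1, 3]), clipped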