mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): use tortoise conditioning encoder
parent 69a599d403
commit 6ecf47312c
TTS/tts/layers/tortoise/autoregressive.py

@@ -176,7 +176,6 @@ class ConditioningEncoder(nn.Module):
         embedding_dim,
         attn_blocks=6,
         num_attn_heads=4,
-        mean=False,
     ):
         super().__init__()
         attn = []
@@ -185,15 +184,14 @@ class ConditioningEncoder(nn.Module):
             attn.append(AttentionBlock(embedding_dim, num_attn_heads))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
-        self.mean = mean

     def forward(self, x):
+        """
+        x: (b, 80, s)
+        """
         h = self.init(x)
         h = self.attn(h)
-        if self.mean:
-            return h.mean(dim=2)
-        else:
-            return h[:, :, 0]
+        return h


 class LearnedPositionEmbeddings(nn.Module):
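Net effect of these two hunks: ConditioningEncoder.forward no longer reduces its output internally (the deleted mean flag chose between h.mean(dim=2) and h[:, :, 0]); it now returns the full (b, embedding_dim, s) sequence and each caller applies its own reduction. A minimal sketch of the new contract, assuming the constructor takes spec_dim before embedding_dim as the xtts copy deleted below does:

import torch
from TTS.tts.layers.tortoise.autoregressive import ConditioningEncoder

enc = ConditioningEncoder(80, 1024)          # spec_dim=80 mel bins, embedding_dim=1024
mel = torch.randn(2, 80, 120)                # (b, 80, s) conditioning mel
h = enc(mel)                                 # (b, 1024, 120): full sequence, no reduction
first_frame = h[:, :, 0]                     # old mean=False behavior
mean_pooled = h.mean(dim=2)                  # old mean=True behavior
print(first_frame.shape, mean_pooled.shape)  # torch.Size([2, 1024]) twice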
@@ -473,7 +471,7 @@ class UnifiedVoice(nn.Module):
         )
         conds = []
         for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
+            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0])
         conds = torch.stack(conds, dim=1)
         conds = conds.mean(dim=1)
         return conds
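Still in autoregressive.py: appending [:, :, 0] at the call site preserves UnifiedVoice's previous behavior under the old default mean=False. The surrounding loop then averages one vector per conditioning clip; a shape walk-through with hypothetical sizes:

import torch

b, n_clips, d = 2, 3, 1024                # hypothetical batch/clip/embedding sizes
# one (b, d) vector per clip, as produced by conditioning_encoder(...)[:, :, 0]
conds = [torch.randn(b, d) for _ in range(n_clips)]
stacked = torch.stack(conds, dim=1)       # (b, n_clips, d)
pooled = stacked.mean(dim=1)              # (b, d): averaged conditioning vector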
TTS/tts/layers/xtts/gpt.py

@@ -8,12 +8,12 @@ import torch.nn.functional as F
 from transformers import GPT2Config

 from TTS.tts.layers.tortoise.autoregressive import (
+    ConditioningEncoder,
     LearnedPositionEmbeddings,
     _prepare_attention_mask_for_generation,
     build_hf_gpt_transformer,
 )
 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
-from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
 from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

@@ -235,19 +235,6 @@ class GPT(nn.Module):
         else:
             return first_logits

-    def get_conditioning(self, speech_conditioning_input):
-        speech_conditioning_input = (
-            speech_conditioning_input.unsqueeze(1)
-            if len(speech_conditioning_input.shape) == 3
-            else speech_conditioning_input
-        )
-        conds = []
-        for j in range(speech_conditioning_input.shape[1]):
-            conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
-        conds = torch.stack(conds, dim=1)
-        conds = conds.mean(dim=1)
-        return conds
-
     def get_prompts(self, prompt_codes):
         """
         Create a prompt from the mel codes. This is used to condition the model on the mel codes.
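The deleted GPT.get_conditioning duplicated the UnifiedVoice.get_conditioning logic in the first file; removing that duplication is the point of this refactor. Its first step normalized a single-clip (b, 80, s) input to (b, 1, 80, s) before looping over clips; a sketch of that normalization with hypothetical tensors:

import torch

x = torch.randn(2, 80, 120)                     # one conditioning clip, (b, 80, s)
x = x.unsqueeze(1) if len(x.shape) == 3 else x  # -> (b, 1, 80, s)
assert x.shape == (2, 1, 80, 120)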
@@ -286,6 +273,7 @@ class GPT(nn.Module):
         """
         cond_input: (b, 80, s) or (b, 1, 80, s)
         conds: (b, 1024, s)
+        output: (b, 1024, 32)
         """
         conds = None
         if not return_latent:
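The added output line documents that the conditioning path emits a fixed 32 latent frames regardless of the input length s, consistent with the PerceiverResampler imported above (a resampler of this kind maps a variable-length sequence onto a fixed set of learned latents). A hypothetical check of that contract:

import torch

def check_cond_contract(cond_input: torch.Tensor, output: torch.Tensor) -> None:
    # cond_input: (b, 80, s) or (b, 1, 80, s); output: (b, 1024, 32)
    assert cond_input.shape[-2] == 80
    assert output.shape == (cond_input.shape[0], 1024, 32)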
TTS/tts/layers/xtts/latent_encoder.py

@@ -93,28 +93,3 @@ class AttentionBlock(nn.Module):
         h = self.proj_out(h)
         xp = self.x_proj(x)
         return (xp + h).reshape(b, xp.shape[1], *spatial)
-
-
-class ConditioningEncoder(nn.Module):
-    def __init__(
-        self,
-        spec_dim,
-        embedding_dim,
-        attn_blocks=6,
-        num_attn_heads=4,
-    ):
-        super().__init__()
-        attn = []
-        self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
-        for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads))
-        self.attn = nn.Sequential(*attn)
-        self.dim = embedding_dim
-
-    def forward(self, x):
-        """
-        x: (b, 80, s)
-        """
-        h = self.init(x)
-        h = self.attn(h)
-        return h
TTS/tts/models/xtts.py

@@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate):
     return audio


-def pad_or_truncate(t, length):
-    """
-    Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.
-
-    Args:
-        t (torch.Tensor): The input tensor to be padded or truncated.
-        length (int): The desired length of the tensor.
-
-    Returns:
-        torch.Tensor: The padded or truncated tensor.
-    """
-    tp = t[..., :length]
-    if t.shape[-1] == length:
-        tp = t
-    elif t.shape[-1] < length:
-        tp = F.pad(t, (0, length - t.shape[-1]))
-    return tp
-
-
 @dataclass
 class XttsAudioConfig(Coqpit):
     """
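The removed pad_or_truncate helper was apparently no longer referenced after this refactor. For reference, its behavior fits in one expression; a sketch, not part of the commit:

import torch
import torch.nn.functional as F

def pad_or_truncate(t: torch.Tensor, length: int) -> torch.Tensor:
    # right-pad with zeros up to `length`, then clip to exactly `length`
    return F.pad(t, (0, max(0, length - t.shape[-1])))[..., :length]

assert pad_or_truncate(torch.ones(1, 5), 8).shape == (1, 8)
assert pad_or_truncate(torch.ones(1, 5), 3).shape == (1, 3)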