mirror of https://github.com/coqui-ai/TTS.git

Add Perceiver

parent 1fb6c203ab
commit dff3902ca8
@@ -11,7 +11,7 @@ from transformers import GPT2Config

 from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
 from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
+from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -105,6 +105,8 @@ class GPT(nn.Module):
         checkpointing=False,
         average_conditioning_embeddings=False,
         label_smoothing=0.0,
+        use_perceiver_resampler=False,
+        perceiver_cond_length_compression=256,
     ):
         """
         Args:
@@ -132,6 +134,8 @@ class GPT(nn.Module):
         self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
         self.conditioning_dropout = nn.Dropout1d(0.1)
         self.average_conditioning_embeddings = average_conditioning_embeddings
+        self.use_perceiver_resampler = use_perceiver_resampler
+        self.perceiver_cond_length_compression = perceiver_cond_length_compression

         self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
         self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
@@ -165,9 +169,22 @@ class GPT(nn.Module):
         self.text_head = nn.Linear(model_dim, self.number_text_tokens)
         self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)

+        if self.use_perceiver_resampler:
+            self.conditioning_perceiver = PerceiverResampler(
+                dim=model_dim,
+                depth=2,
+                dim_context=model_dim,
+                num_latents=32,
+                dim_head=64,
+                heads=8,
+                ff_mult=4,
+                use_flash_attn=False,
+            )

     def get_grad_norm_parameter_groups(self):
         return {
             "conditioning_encoder": list(self.conditioning_encoder.parameters()),
+            "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None,
             "gpt": list(self.gpt.parameters()),
             "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
         }
@@ -250,9 +267,6 @@ class GPT(nn.Module):
         if attn_mask_text is not None:
             attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
             if prompt is not None:
-                if attn_mask_cond is not None:
-                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
-                else:
                 attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
                 attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

@@ -318,7 +332,6 @@ class GPT(nn.Module):
         prompt_len = 3
         prompt_len = prompt_len * 24  # in frames
         if prompt_codes.shape[-1] >= prompt_len:
-            new_prompt = []
             for i in range(prompt_codes.shape[0]):
                 if lengths[i] < prompt_len:
                     start = 0
@@ -340,7 +353,9 @@ class GPT(nn.Module):
         if not return_latent:
             if cond_input.ndim == 4:
                 cond_input = cond_input.squeeze(1)
-            conds = self.conditioning_encoder(cond_input)
+            conds = self.conditioning_encoder(cond_input)  # (b, d, s)
+            if self.use_perceiver_resampler:
+                conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2)  # (b, d, 32)
         else:
             # already computed
             conds = cond_input.unsqueeze(1)
@@ -354,6 +369,7 @@ class GPT(nn.Module):
         wav_lengths,
         cond_mels=None,
         cond_idxs=None,
+        cond_lens=None,
         cond_latents=None,
         return_attentions=False,
         return_latent=False,
@@ -379,6 +395,12 @@ class GPT(nn.Module):
         max_text_len = text_lengths.max()
         code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

+        if cond_lens is not None:
+            if self.use_perceiver_resampler:
+                cond_lens = cond_lens // self.perceiver_cond_length_compression
+            else:
+                cond_lens = cond_lens // self.code_stride_len
+
         if cond_idxs is not None:
             # recompute cond idxs for mel lengths
             for idx, l in enumerate(code_lengths):
@@ -450,9 +472,13 @@ class GPT(nn.Module):
             )

             if cond_idxs is not None:
+                # use masking approach
                 for idx, r in enumerate(cond_idxs):
                     l = r[1] - r[0]
                     attn_mask_cond[idx, l:] = 0.0
+            elif cond_lens is not None:
+                for idx, l in enumerate(cond_lens):
+                    attn_mask_cond[idx, l:] = 0.0

             for idx, l in enumerate(text_lengths):
                 attn_mask_text[idx, l + 1 :] = 0.0
@@ -468,6 +494,10 @@ class GPT(nn.Module):

         # Compute speech conditioning input
         if cond_latents is None:
+            if cond_lens is not None:
+                min_cond_len = torch.min(cond_lens)
+                cond_mels = cond_mels[:, :, :, :min_cond_len]
+
             cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

         # Get logits
@@ -483,7 +513,7 @@ class GPT(nn.Module):
             prompt=cond_latents,
             get_attns=return_attentions,
             return_latent=return_latent,
-            attn_mask_cond=attn_mask_cond,
+            attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
             attn_mask_text=attn_mask_text,
             attn_mask_mel=attn_mask_mel,
         )
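
Note on the conditioning path above: when use_perceiver_resampler is enabled, the output of self.conditioning_encoder (shape (b, d, s)) is compressed to 32 learned latents by the PerceiverResampler added in the new module below, using the permute/transpose pattern shown in the hunk. A minimal, standalone sketch of that shape flow; the model_dim of 1024 and the frame count are illustrative placeholders, not values taken from the XTTS config:

import torch

from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

model_dim, frames = 1024, 400  # illustrative sizes only
conds = torch.randn(2, model_dim, frames)  # stand-in for the ConditioningEncoder output, (b, d, s)

perceiver = PerceiverResampler(
    dim=model_dim, depth=2, dim_context=model_dim, num_latents=32, dim_head=64, heads=8, ff_mult=4
)
out = perceiver(conds.permute(0, 2, 1)).transpose(1, 2)  # same call pattern as the hunk above
print(out.shape)  # torch.Size([2, 1024, 32]), fixed length regardless of `frames`
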
@@ -0,0 +1,321 @@
+# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+from collections import namedtuple
+from functools import wraps
+from packaging import version
+
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+
+def exists(val):
+    return val is not None
+
+
+def once(fn):
+    called = False
+    @wraps(fn)
+    def inner(x):
+        nonlocal called
+        if called:
+            return
+        called = True
+        return fn(x)
+    return inner
+
+
+print_once = once(print)
+
+# main class
+
+class Attend(nn.Module):
+    def __init__(
+        self,
+        dropout = 0.,
+        causal = False,
+        use_flash = False
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.attn_dropout = nn.Dropout(dropout)
+
+        self.causal = causal
+        self.register_buffer("mask", None, persistent=False)
+
+        self.use_flash = use_flash
+        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
+
+        # determine efficient attention configs for cuda and cpu
+        self.config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
+        self.cpu_config = self.config(True, True, True)
+        self.cuda_config = None
+
+        if not torch.cuda.is_available() or not use_flash:
+            return
+
+        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+        if device_properties.major == 8 and device_properties.minor == 0:
+            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
+            self.cuda_config = self.config(True, False, False)
+        else:
+            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
+            self.cuda_config = self.config(False, True, True)
+
+    def get_mask(self, n, device):
+        if exists(self.mask) and self.mask.shape[-1] >= n:
+            return self.mask[:n, :n]
+
+        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
+        self.register_buffer("mask", mask, persistent=False)
+        return mask
+
+    def flash_attn(self, q, k, v, mask = None):
+        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
+
+        # Recommended for multi-query single-key-value attention by Tri Dao
+        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
+
+        if k.ndim == 3:
+            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
+
+        if v.ndim == 3:
+            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
+
+        # Check if mask exists and expand to compatible shape
+        # The mask is B L, so it would have to be expanded to B H N L
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            mask = mask.expand(-1, heads, q_len, -1)
+
+        # Check if there is a compatible device for flash attention
+
+        config = self.cuda_config if is_cuda else self.cpu_config
+
+        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
+
+        with torch.backends.cuda.sdp_kernel(**config._asdict()):
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask = mask,
+                dropout_p = self.dropout if self.training else 0.,
+                is_causal = self.causal
+            )
+
+        return out
+
+    def forward(self, q, k, v, mask = None):
+        """
+        einstein notation
+        b - batch
+        h - heads
+        n, i, j - sequence length (base sequence length, source, target)
+        d - feature dimension
+        """
+
+        n, device = q.shape[-2], q.device
+
+        scale = q.shape[-1] ** -0.5
+
+        if self.use_flash:
+            return self.flash_attn(q, k, v, mask = mask)
+
+        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
+
+        # similarity
+
+        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
+
+        # key padding mask
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
+
+        # causal mask
+
+        if self.causal:
+            causal_mask = self.get_mask(n, device)
+            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
+
+        # attention
+
+        attn = sim.softmax(dim=-1)
+        attn = self.attn_dropout(attn)
+
+        # aggregate values
+
+        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
+
+        return out
+
+
+def Sequential(*mods):
+    return nn.Sequential(*filter(exists, mods))
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, scale=True, dim_cond=None):
+        super().__init__()
+        self.cond = exists(dim_cond)
+        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
+
+        self.scale = dim**0.5
+        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
+
+    def forward(self, x, cond=None):
+        gamma = default(self.gamma, 1)
+        out = F.normalize(x, dim=-1) * self.scale * gamma
+
+        if not self.cond:
+            return out
+
+        assert exists(cond)
+        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
+        gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
+        return out * gamma + beta
+
+
+class CausalConv1d(nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        (kernel_size,) = self.kernel_size
+        (dilation,) = self.dilation
+        (stride,) = self.stride
+
+        assert stride == 1
+        self.causal_padding = dilation * (kernel_size - 1)
+
+    def forward(self, x):
+        causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
+        return super().forward(causal_padded_x)
+
+
+class GEGLU(nn.Module):
+    def forward(self, x):
+        x, gate = x.chunk(2, dim=-1)
+        return F.gelu(gate) * x
+
+
+def FeedForward(dim, mult=4, causal_conv=False):
+    dim_inner = int(dim * mult * 2 / 3)
+
+    conv = None
+    if causal_conv:
+        conv = nn.Sequential(
+            Rearrange("b n d -> b d n"),
+            CausalConv1d(dim_inner, dim_inner, 3),
+            Rearrange("b d n -> b n d"),
+        )
+
+    return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
+
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=2,
+        dim_context=None,
+        num_latents=32,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        use_flash_attn=False,
+    ):
+        super().__init__()
+        dim_context = default(dim_context, dim)
+
+        self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
+
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        nn.init.normal_(self.latents, std=0.02)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        Attention(
+                            dim=dim,
+                            dim_head=dim_head,
+                            heads=heads,
+                            use_flash=use_flash_attn,
+                            cross_attn_include_queries=True,
+                        ),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )

+        self.norm = RMSNorm(dim)
+
+    def forward(self, x, mask=None):
+        batch = x.shape[0]
+
+        x = self.proj_context(x)
+
+        latents = repeat(self.latents, "n d -> b n d", b=batch)
+
+        for attn, ff in self.layers:
+            latents = attn(latents, x, mask=mask) + latents
+            latents = ff(latents) + latents
+
+        return self.norm(latents)
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        dim_context=None,
+        causal=False,
+        dim_head=64,
+        heads=8,
+        dropout=0.0,
+        use_flash=False,
+        cross_attn_include_queries=False,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        self.cross_attn_include_queries = cross_attn_include_queries
+
+        dim_inner = dim_head * heads
+        dim_context = default(dim_context, dim)
+
+        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
+        self.to_q = nn.Linear(dim, dim_inner, bias=False)
+        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
+        self.to_out = nn.Linear(dim_inner, dim, bias=False)
+
+    def forward(self, x, context=None, mask=None):
+        h, has_context = self.heads, exists(context)
+
+        context = default(context, x)
+
+        if has_context and self.cross_attn_include_queries:
+            context = torch.cat((x, context), dim=-2)
+
+        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+        out = self.attend(q, k, v, mask=mask)
+
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
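
The Attend module above either dispatches to torch 2.0's F.scaled_dot_product_attention (use_flash=True) or computes softmax(q k^T / sqrt(d)) v with einsum. A quick sanity sketch of the non-flash path, assuming the file is importable as TTS.tts.layers.xtts.perceiver_encoder (the path used by the gpt.py import above); with no mask, no dropout, and causal=False the two paths should agree to floating-point tolerance:

import torch
import torch.nn.functional as F

from TTS.tts.layers.xtts.perceiver_encoder import Attend

attend = Attend(dropout=0.0, causal=False, use_flash=False).eval()
q = torch.randn(1, 8, 32, 64)   # (batch, heads, queries, dim_head)
k = torch.randn(1, 8, 100, 64)  # (batch, heads, keys, dim_head)
v = torch.randn(1, 8, 100, 64)

with torch.no_grad():
    out = attend(q, k, v)                           # einsum path: softmax(q k^T / sqrt(d)) v
    ref = F.scaled_dot_product_attention(q, k, v)   # default scale is also 1 / sqrt(d)

print(torch.allclose(out, ref, atol=1e-5))  # expected: True
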
@@ -142,6 +142,19 @@ class GPTTrainer(BaseTTS):
             print(">> GPT weights restored from:", self.args.gpt_checkpoint)

         # Mel spectrogram extractor for conditioning
+        if self.args.gpt_use_perceiver_resampler:
+            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
+                filter_length=2048,
+                hop_length=256,
+                win_length=1024,
+                normalize=False,
+                sampling_rate=config.audio.sample_rate,
+                mel_fmin=0,
+                mel_fmax=8000,
+                n_mel_channels=80,
+                mel_norm_file=self.args.mel_norm_file,
+            )
+        else:
             self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                 filter_length=4096,
                 hop_length=1024,
@@ -23,7 +23,19 @@ init_stream_support()


 def wav_to_mel_cloning(
-    wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
+    wav,
+    mel_norms_file="../experiments/clips_mel_norms.pth",
+    mel_norms=None,
+    device=torch.device("cpu"),
+    n_fft=4096,
+    hop_length=1024,
+    win_length=4096,
+    power=2,
+    normalized=False,
+    sample_rate=22050,
+    f_min=0,
+    f_max=8000,
+    n_mels=80,
 ):
     """
     Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
@@ -38,15 +50,15 @@ def wav_to_mel_cloning(
         torch.Tensor: Mel-spectrogram tensor.
     """
     mel_stft = torchaudio.transforms.MelSpectrogram(
-        n_fft=4096,
-        hop_length=1024,
-        win_length=4096,
-        power=2,
-        normalized=False,
-        sample_rate=22050,
-        f_min=0,
-        f_max=8000,
-        n_mels=80,
+        n_fft=n_fft,
+        hop_length=hop_length,
+        win_length=win_length,
+        power=power,
+        normalized=normalized,
+        sample_rate=sample_rate,
+        f_min=f_min,
+        f_max=f_max,
+        n_mels=n_mels,
         norm="slaney",
     ).to(device)
     wav = wav.to(device)
@@ -229,6 +241,7 @@ class XttsArgs(Coqpit):
     gpt_num_audio_tokens: int = 8194
     gpt_start_audio_token: int = 8192
     gpt_stop_audio_token: int = 8193
+    gpt_use_perceiver_resampler: bool = False

     # Diffusion Decoder params
     diff_model_channels: int = 1024
@@ -304,6 +317,7 @@ class Xtts(BaseTTS):
             num_audio_tokens=self.args.gpt_num_audio_tokens,
             start_audio_token=self.args.gpt_start_audio_token,
             stop_audio_token=self.args.gpt_stop_audio_token,
+            use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
         )

         if self.args.use_hifigan:
@@ -359,7 +373,32 @@ class Xtts(BaseTTS):

         audio_22k = torchaudio.functional.resample(audio, sr, 22050)
         audio_22k = audio_22k[:, : 22050 * length]
-        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
+        if self.args.gpt_use_perceiver_resampler:
+            mel = wav_to_mel_cloning(audio_22k,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=2048,
+                hop_length=256,
+                win_length=1024,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80
+            )
+        else:
+            mel = wav_to_mel_cloning(audio_22k,
+                mel_norms=self.mel_stats.cpu(),
+                n_fft=4096,
+                hop_length=1024,
+                win_length=4096,
+                power=2,
+                normalized=False,
+                sample_rate=22050,
+                f_min=0,
+                f_max=8000,
+                n_mels=80
+            )
         cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

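
Taken together: gpt_use_perceiver_resampler defaults to False, so existing configs keep the old 4096/1024 mel settings; enabling it switches both the trainer's style-encoder spectrogram and the cloning path above to n_fft=2048, hop_length=256, win_length=1024, and the GPT conditioning is resampled to a fixed 32 latents. perceiver_cond_length_compression=256 mirrors that hop length, presumably so cond_lens can be converted to frame counts in GPT.forward. Rough frame arithmetic as a sketch (the 6-second reference length below is only an example, not a value from the code):

# With the perceiver mel settings above (22050 Hz, hop_length=256),
# one second of reference audio yields roughly 22050 / 256, i.e. about 86 mel frames.
seconds = 6  # illustrative reference length
frames = seconds * 22050 // 256
print(frames)  # ~516 frames in, 32 latents out of the PerceiverResampler
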