Add Perceiver

This commit is contained in:
Edresson Casanova 2023-11-01 12:18:35 -03:00 committed by Eren G??lge
parent 1fb6c203ab
commit dff3902ca8
4 changed files with 434 additions and 31 deletions

View File

@ -11,7 +11,7 @@ from transformers import GPT2Config
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@ -105,6 +105,8 @@ class GPT(nn.Module):
checkpointing=False,
average_conditioning_embeddings=False,
label_smoothing=0.0,
use_perceiver_resampler=False,
perceiver_cond_length_compression=256,
):
"""
Args:
@ -132,6 +134,8 @@ class GPT(nn.Module):
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.conditioning_dropout = nn.Dropout1d(0.1)
self.average_conditioning_embeddings = average_conditioning_embeddings
self.use_perceiver_resampler = use_perceiver_resampler
self.perceiver_cond_length_compression = perceiver_cond_length_compression
self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
@ -165,9 +169,22 @@ class GPT(nn.Module):
self.text_head = nn.Linear(model_dim, self.number_text_tokens)
self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)
if self.use_perceiver_resampler:
self.conditioning_perceiver = PerceiverResampler(
dim=model_dim,
depth=2,
dim_context=model_dim,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
)
def get_grad_norm_parameter_groups(self):
return {
"conditioning_encoder": list(self.conditioning_encoder.parameters()),
"conditioning_perceiver": list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None,
"gpt": list(self.gpt.parameters()),
"heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
}
@ -250,11 +267,8 @@ class GPT(nn.Module):
if attn_mask_text is not None:
attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
if prompt is not None:
if attn_mask_cond is not None:
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
else:
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb,
@ -318,7 +332,6 @@ class GPT(nn.Module):
prompt_len = 3
prompt_len = prompt_len * 24 # in frames
if prompt_codes.shape[-1] >= prompt_len:
new_prompt = []
for i in range(prompt_codes.shape[0]):
if lengths[i] < prompt_len:
start = 0
@ -340,7 +353,9 @@ class GPT(nn.Module):
if not return_latent:
if cond_input.ndim == 4:
cond_input = cond_input.squeeze(1)
conds = self.conditioning_encoder(cond_input)
conds = self.conditioning_encoder(cond_input) # (b, d, s)
if self.use_perceiver_resampler:
conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32)
else:
# already computed
conds = cond_input.unsqueeze(1)
@ -354,6 +369,7 @@ class GPT(nn.Module):
wav_lengths,
cond_mels=None,
cond_idxs=None,
cond_lens=None,
cond_latents=None,
return_attentions=False,
return_latent=False,
@ -379,6 +395,12 @@ class GPT(nn.Module):
max_text_len = text_lengths.max()
code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3
if cond_lens is not None:
if self.use_perceiver_resampler:
cond_lens = cond_lens // self.perceiver_cond_length_compression
else:
cond_lens = cond_lens // self.code_stride_len
if cond_idxs is not None:
# recompute cond idxs for mel lengths
for idx, l in enumerate(code_lengths):
@ -450,9 +472,13 @@ class GPT(nn.Module):
)
if cond_idxs is not None:
# use masking approach
for idx, r in enumerate(cond_idxs):
l = r[1] - r[0]
attn_mask_cond[idx, l:] = 0.0
elif cond_lens is not None:
for idx, l in enumerate(cond_lens):
attn_mask_cond[idx, l:] = 0.0
for idx, l in enumerate(text_lengths):
attn_mask_text[idx, l + 1 :] = 0.0
@ -468,6 +494,10 @@ class GPT(nn.Module):
# Compute speech conditioning input
if cond_latents is None:
if cond_lens is not None:
min_cond_len = torch.min(cond_lens)
cond_mels = cond_mels[:, :, :, :min_cond_len]
cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)
# Get logits
@ -483,7 +513,7 @@ class GPT(nn.Module):
prompt=cond_latents,
get_attns=return_attentions,
return_latent=return_latent,
attn_mask_cond=attn_mask_cond,
attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
attn_mask_text=attn_mask_text,
attn_mask_mel=attn_mask_mel,
)

View File

@ -0,0 +1,321 @@
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
import torch
from torch import nn, einsum
import torch.nn.functional as F
from collections import namedtuple
from functools import wraps
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def exists(val):
return val is not None
def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner
print_once = once(print)
# main class
class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
causal = False,
use_flash = False
):
super().__init__()
self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)
self.causal = causal
self.register_buffer("mask", None, persistent=False)
self.use_flash = use_flash
assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
# determine efficient attention configs for cuda and cpu
self.config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
self.cpu_config = self.config(True, True, True)
self.cuda_config = None
if not torch.cuda.is_available() or not use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = self.config(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = self.config(False, True, True)
def get_mask(self, n, device):
if exists(self.mask) and self.mask.shape[-1] >= n:
return self.mask[:n, :n]
mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
self.register_buffer("mask", mask, persistent=False)
return mask
def flash_attn(self, q, k, v, mask = None):
_, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if k.ndim == 3:
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
if v.ndim == 3:
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
# Check if mask exists and expand to compatible shape
# The mask is B L, so it would have to be expanded to B H N L
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
mask = mask.expand(-1, heads, q_len, -1)
# Check if there is a compatible device for flash attention
config = self.cuda_config if is_cuda else self.cpu_config
# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = self.causal
)
return out
def forward(self, q, k, v, mask = None):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""
n, device = q.shape[-2], q.device
scale = q.shape[-1] ** -0.5
if self.use_flash:
return self.flash_attn(q, k, v, mask = mask)
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
# similarity
sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
# key padding mask
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
# causal mask
if self.causal:
causal_mask = self.get_mask(n, device)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
# attention
attn = sim.softmax(dim=-1)
attn = self.attn_dropout(attn)
# aggregate values
out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)
return out
def Sequential(*mods):
return nn.Sequential(*filter(exists, mods))
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
class RMSNorm(nn.Module):
def __init__(self, dim, scale=True, dim_cond=None):
super().__init__()
self.cond = exists(dim_cond)
self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(dim)) if scale else None
def forward(self, x, cond=None):
gamma = default(self.gamma, 1)
out = F.normalize(x, dim=-1) * self.scale * gamma
if not self.cond:
return out
assert exists(cond)
gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
return out * gamma + beta
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(kernel_size,) = self.kernel_size
(dilation,) = self.dilation
(stride,) = self.stride
assert stride == 1
self.causal_padding = dilation * (kernel_size - 1)
def forward(self, x):
causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
return super().forward(causal_padded_x)
class GEGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(gate) * x
def FeedForward(dim, mult=4, causal_conv=False):
dim_inner = int(dim * mult * 2 / 3)
conv = None
if causal_conv:
conv = nn.Sequential(
Rearrange("b n d -> b d n"),
CausalConv1d(dim_inner, dim_inner, 3),
Rearrange("b d n -> b n d"),
)
return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
class PerceiverResampler(nn.Module):
def __init__(
self,
*,
dim,
depth=2,
dim_context=None,
num_latents=32,
dim_head=64,
heads=8,
ff_mult=4,
use_flash_attn=False,
):
super().__init__()
dim_context = default(dim_context, dim)
self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()
self.latents = nn.Parameter(torch.randn(num_latents, dim))
nn.init.normal_(self.latents, std=0.02)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(
dim=dim,
dim_head=dim_head,
heads=heads,
use_flash=use_flash_attn,
cross_attn_include_queries=True,
),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.norm = RMSNorm(dim)
def forward(self, x, mask=None):
batch = x.shape[0]
x = self.proj_context(x)
latents = repeat(self.latents, "n d -> b n d", b=batch)
for attn, ff in self.layers:
latents = attn(latents, x, mask=mask) + latents
latents = ff(latents) + latents
return self.norm(latents)
class Attention(nn.Module):
def __init__(
self,
dim,
*,
dim_context=None,
causal=False,
dim_head=64,
heads=8,
dropout=0.0,
use_flash=False,
cross_attn_include_queries=False,
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
self.cross_attn_include_queries = cross_attn_include_queries
dim_inner = dim_head * heads
dim_context = default(dim_context, dim)
self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
self.to_q = nn.Linear(dim, dim_inner, bias=False)
self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
self.to_out = nn.Linear(dim_inner, dim, bias=False)
def forward(self, x, context=None, mask=None):
h, has_context = self.heads, exists(context)
context = default(context, x)
if has_context and self.cross_attn_include_queries:
context = torch.cat((x, context), dim=-2)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = self.attend(q, k, v, mask=mask)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)

View File

@ -142,17 +142,30 @@ class GPTTrainer(BaseTTS):
print(">> GPT weights restored from:", self.args.gpt_checkpoint)
# Mel spectrogram extractor for conditioning
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
filter_length=4096,
hop_length=1024,
win_length=4096,
normalize=False,
sampling_rate=config.audio.sample_rate,
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file,
)
if self.args.gpt_use_perceiver_resampler:
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
filter_length=2048,
hop_length=256,
win_length=1024,
normalize=False,
sampling_rate=config.audio.sample_rate,
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file,
)
else:
self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
filter_length=4096,
hop_length=1024,
win_length=4096,
normalize=False,
sampling_rate=config.audio.sample_rate,
mel_fmin=0,
mel_fmax=8000,
n_mel_channels=80,
mel_norm_file=self.args.mel_norm_file,
)
# Load DVAE
self.dvae = DiscreteVAE(

View File

@ -23,7 +23,19 @@ init_stream_support()
def wav_to_mel_cloning(
wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
wav,
mel_norms_file="../experiments/clips_mel_norms.pth",
mel_norms=None,
device=torch.device("cpu"),
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
):
"""
Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
@ -38,15 +50,15 @@ def wav_to_mel_cloning(
torch.Tensor: Mel-spectrogram tensor.
"""
mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=power,
normalized=normalized,
sample_rate=sample_rate,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
norm="slaney",
).to(device)
wav = wav.to(device)
@ -229,6 +241,7 @@ class XttsArgs(Coqpit):
gpt_num_audio_tokens: int = 8194
gpt_start_audio_token: int = 8192
gpt_stop_audio_token: int = 8193
gpt_use_perceiver_resampler: bool = False
# Diffusion Decoder params
diff_model_channels: int = 1024
@ -304,6 +317,7 @@ class Xtts(BaseTTS):
num_audio_tokens=self.args.gpt_num_audio_tokens,
start_audio_token=self.args.gpt_start_audio_token,
stop_audio_token=self.args.gpt_stop_audio_token,
use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
)
if self.args.use_hifigan:
@ -359,7 +373,32 @@ class Xtts(BaseTTS):
audio_22k = torchaudio.functional.resample(audio, sr, 22050)
audio_22k = audio_22k[:, : 22050 * length]
mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
if self.args.gpt_use_perceiver_resampler:
mel = wav_to_mel_cloning(audio_22k,
mel_norms=self.mel_stats.cpu(),
n_fft=2048,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80
)
else:
mel = wav_to_mel_cloning(audio_22k,
mel_norms=self.mel_stats.cpu(),
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)