mirror of https://github.com/coqui-ai/TTS.git
Add Perceiver
This commit is contained in:
parent 1fb6c203ab
commit dff3902ca8
@@ -11,7 +11,7 @@ from transformers import GPT2Config
from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel
from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder
from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler


def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -105,6 +105,8 @@ class GPT(nn.Module):
        checkpointing=False,
        average_conditioning_embeddings=False,
        label_smoothing=0.0,
        use_perceiver_resampler=False,
        perceiver_cond_length_compression=256,
    ):
        """
        Args:
@@ -132,6 +134,8 @@ class GPT(nn.Module):
        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
        self.conditioning_dropout = nn.Dropout1d(0.1)
        self.average_conditioning_embeddings = average_conditioning_embeddings
        self.use_perceiver_resampler = use_perceiver_resampler
        self.perceiver_cond_length_compression = perceiver_cond_length_compression

        self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim)
        self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
@@ -165,9 +169,22 @@ class GPT(nn.Module):
        self.text_head = nn.Linear(model_dim, self.number_text_tokens)
        self.mel_head = nn.Linear(model_dim, self.num_audio_tokens)

        if self.use_perceiver_resampler:
            self.conditioning_perceiver = PerceiverResampler(
                dim=model_dim,
                depth=2,
                dim_context=model_dim,
                num_latents=32,
                dim_head=64,
                heads=8,
                ff_mult=4,
                use_flash_attn=False,
            )

    def get_grad_norm_parameter_groups(self):
        return {
            "conditioning_encoder": list(self.conditioning_encoder.parameters()),
            "conditioning_perceiver": list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None,
            "gpt": list(self.gpt.parameters()),
            "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
        }
@@ -250,11 +267,8 @@ class GPT(nn.Module):
        if attn_mask_text is not None:
            attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
            if prompt is not None:
                if attn_mask_cond is not None:
                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
                else:
                    attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
                    attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)
                attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
                attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1)

        gpt_out = self.gpt(
            inputs_embeds=emb,
@@ -318,7 +332,6 @@ class GPT(nn.Module):
        prompt_len = 3
        prompt_len = prompt_len * 24  # in frames
        if prompt_codes.shape[-1] >= prompt_len:
            new_prompt = []
            for i in range(prompt_codes.shape[0]):
                if lengths[i] < prompt_len:
                    start = 0
@@ -340,7 +353,9 @@ class GPT(nn.Module):
        if not return_latent:
            if cond_input.ndim == 4:
                cond_input = cond_input.squeeze(1)
            conds = self.conditioning_encoder(cond_input)
            conds = self.conditioning_encoder(cond_input)  # (b, d, s)
            if self.use_perceiver_resampler:
                conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2)  # (b, d, 32)
        else:
            # already computed
            conds = cond_input.unsqueeze(1)
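For orientation, here is a rough shape walkthrough of the resampled conditioning path above. It is only a sketch: the batch size, conditioning length, and model_dim value are assumptions rather than values from the diff, and stand-in tensors replace the real encoder modules.

import torch

b, s, model_dim = 2, 427, 1024             # assumed sizes, for illustration only
cond_input = torch.randn(b, 80, s)         # 80-bin conditioning mel spectrogram
conds = torch.randn(b, model_dim, s)       # stand-in for ConditioningEncoder(cond_input): (b, d, s)
# The perceiver consumes (b, s, d) and always returns 32 latents per sample,
# so the conditioning becomes a fixed-size tensor regardless of clip length.
resampled = torch.randn(b, 32, model_dim)  # stand-in for conditioning_perceiver(conds.permute(0, 2, 1))
conds = resampled.transpose(1, 2)          # (b, d, 32), as in the branch above
assert conds.shape == (b, model_dim, 32)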
@@ -354,6 +369,7 @@ class GPT(nn.Module):
        wav_lengths,
        cond_mels=None,
        cond_idxs=None,
        cond_lens=None,
        cond_latents=None,
        return_attentions=False,
        return_latent=False,
@@ -379,6 +395,12 @@ class GPT(nn.Module):
        max_text_len = text_lengths.max()
        code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3

        if cond_lens is not None:
            if self.use_perceiver_resampler:
                cond_lens = cond_lens // self.perceiver_cond_length_compression
            else:
                cond_lens = cond_lens // self.code_stride_len

        if cond_idxs is not None:
            # recompute cond idxs for mel lengths
            for idx, l in enumerate(code_lengths):
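Assuming cond_lens is measured in audio samples (the division by code_stride_len suggests this, but it is an inference, not stated in the diff), both branches simply convert sample counts to conditioning-frame counts at different hop sizes. A quick arithmetic sketch:

# Hypothetical 6-second conditioning clip at 22.05 kHz; all values here are illustrative.
cond_len_samples = 6 * 22050                    # 132300 samples
perceiver_cond_length_compression = 256         # default added to the GPT constructor above
code_stride_len = 1024                          # assumed stride on the non-perceiver path

print(cond_len_samples // perceiver_cond_length_compression)  # 516 frames
print(cond_len_samples // code_stride_len)                     # 129 frames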
@@ -450,9 +472,13 @@ class GPT(nn.Module):
            )

            if cond_idxs is not None:
                # use masking approach
                for idx, r in enumerate(cond_idxs):
                    l = r[1] - r[0]
                    attn_mask_cond[idx, l:] = 0.0
            elif cond_lens is not None:
                for idx, l in enumerate(cond_lens):
                    attn_mask_cond[idx, l:] = 0.0

            for idx, l in enumerate(text_lengths):
                attn_mask_text[idx, l + 1 :] = 0.0
@@ -468,6 +494,10 @@ class GPT(nn.Module):
        # Compute speech conditioning input
        if cond_latents is None:
            if cond_lens is not None:
                min_cond_len = torch.min(cond_lens)
                cond_mels = cond_mels[:, :, :, :min_cond_len]

            cond_latents = self.get_style_emb(cond_mels).transpose(1, 2)

        # Get logits
@@ -483,7 +513,7 @@ class GPT(nn.Module):
            prompt=cond_latents,
            get_attns=return_attentions,
            return_latent=return_latent,
            attn_mask_cond=attn_mask_cond,
            attn_mask_cond=attn_mask_cond if not self.use_perceiver_resampler else None,
            attn_mask_text=attn_mask_text,
            attn_mask_mel=attn_mask_mel,
        )
@@ -0,0 +1,321 @@
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532

import torch
from torch import nn, einsum
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def exists(val):
    return val is not None


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)

    return inner


print_once = once(print)


# main class
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        causal = False,
        use_flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (use_flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu
        self.config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
        self.cpu_config = self.config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = self.config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = self.config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask
    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = self.causal
            )

        return out
    def forward(self, q, k, v, mask = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        if self.use_flash:
            return self.flash_attn(q, k, v, mask = mask)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        # similarity

        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d


class RMSNorm(nn.Module):
    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = exists(dim_cond)
        self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None

        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim)) if scale else None

    def forward(self, x, cond=None):
        gamma = default(self.gamma, 1)
        out = F.normalize(x, dim=-1) * self.scale * gamma

        if not self.cond:
            return out

        assert exists(cond)
        gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1)
        gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta))
        return out * gamma + beta
class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        (kernel_size,) = self.kernel_size
        (dilation,) = self.dilation
        (stride,) = self.stride

        assert stride == 1
        self.causal_padding = dilation * (kernel_size - 1)

    def forward(self, x):
        causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0)
        return super().forward(causal_padded_x)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * x
def FeedForward(dim, mult=4, causal_conv=False):
    dim_inner = int(dim * mult * 2 / 3)

    conv = None
    if causal_conv:
        conv = nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(dim_inner, dim_inner, 3),
            Rearrange("b d n -> b n d"),
        )

    return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim))
class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity()

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std=0.02)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            use_flash=use_flash_attn,
                            cross_attn_include_queries=True,
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = RMSNorm(dim)

    def forward(self, x, mask=None):
        batch = x.shape[0]

        x = self.proj_context(x)

        latents = repeat(self.latents, "n d -> b n d", b=batch)

        for attn, ff in self.layers:
            latents = attn(latents, x, mask=mask) + latents
            latents = ff(latents) + latents

        return self.norm(latents)
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias=False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
        self.to_out = nn.Linear(dim_inner, dim, bias=False)

    def forward(self, x, context=None, mask=None):
        h, has_context = self.heads, exists(context)

        context = default(context, x)

        if has_context and self.cross_attn_include_queries:
            context = torch.cat((x, context), dim=-2)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = self.attend(q, k, v, mask=mask)

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
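For context, a minimal usage sketch of the resampler defined above, assuming the new file lands at TTS/tts/layers/xtts/perceiver_encoder.py as the import in gpt.py suggests; the sizes mirror the defaults passed by the GPT module but are otherwise arbitrary.

import torch

from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler

resampler = PerceiverResampler(
    dim=1024,          # assumed model_dim; any embedding size works
    depth=2,
    dim_context=1024,
    num_latents=32,
    dim_head=64,
    heads=8,
    ff_mult=4,
    use_flash_attn=False,
)

x = torch.randn(2, 427, 1024)  # (batch, variable conditioning length, dim)
latents = resampler(x)
print(latents.shape)           # torch.Size([2, 32, 1024]): a fixed 32 latents per sample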
@@ -142,17 +142,30 @@ class GPTTrainer(BaseTTS):
            print(">> GPT weights restored from:", self.args.gpt_checkpoint)

        # Mel spectrogram extractor for conditioning
        self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
            filter_length=4096,
            hop_length=1024,
            win_length=4096,
            normalize=False,
            sampling_rate=config.audio.sample_rate,
            mel_fmin=0,
            mel_fmax=8000,
            n_mel_channels=80,
            mel_norm_file=self.args.mel_norm_file,
        )
        if self.args.gpt_use_perceiver_resampler:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=2048,
                hop_length=256,
                win_length=1024,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )
        else:
            self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
                filter_length=4096,
                hop_length=1024,
                win_length=4096,
                normalize=False,
                sampling_rate=config.audio.sample_rate,
                mel_fmin=0,
                mel_fmax=8000,
                n_mel_channels=80,
                mel_norm_file=self.args.mel_norm_file,
            )

        # Load DVAE
        self.dvae = DiscreteVAE(
@@ -23,7 +23,19 @@ init_stream_support()


def wav_to_mel_cloning(
    wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu")
    wav,
    mel_norms_file="../experiments/clips_mel_norms.pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """
    Convert waveform to mel-spectrogram with hard-coded parameters for cloning.
@@ -38,15 +50,15 @@ def wav_to_mel_cloning(
        torch.Tensor: Mel-spectrogram tensor.
    """
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=4096,
        hop_length=1024,
        win_length=4096,
        power=2,
        normalized=False,
        sample_rate=22050,
        f_min=0,
        f_max=8000,
        n_mels=80,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
@@ -229,6 +241,7 @@ class XttsArgs(Coqpit):
    gpt_num_audio_tokens: int = 8194
    gpt_start_audio_token: int = 8192
    gpt_stop_audio_token: int = 8193
    gpt_use_perceiver_resampler: bool = False

    # Diffusion Decoder params
    diff_model_channels: int = 1024
@@ -304,6 +317,7 @@ class Xtts(BaseTTS):
            num_audio_tokens=self.args.gpt_num_audio_tokens,
            start_audio_token=self.args.gpt_start_audio_token,
            stop_audio_token=self.args.gpt_stop_audio_token,
            use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
        )

        if self.args.use_hifigan:
@@ -359,7 +373,32 @@ class Xtts(BaseTTS):

        audio_22k = torchaudio.functional.resample(audio, sr, 22050)
        audio_22k = audio_22k[:, : 22050 * length]
        mel = wav_to_mel_cloning(audio_22k, mel_norms=self.mel_stats.cpu())
        if self.args.gpt_use_perceiver_resampler:
            mel = wav_to_mel_cloning(
                audio_22k,
                mel_norms=self.mel_stats.cpu(),
                n_fft=2048,
                hop_length=256,
                win_length=1024,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
        else:
            mel = wav_to_mel_cloning(
                audio_22k,
                mel_norms=self.mel_stats.cpu(),
                n_fft=4096,
                hop_length=1024,
                win_length=4096,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
        return cond_latent.transpose(1, 2)
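For reference, the two conditioning-mel settings used above differ mainly in hop size; the frame rates below are derived here for illustration, and the hop of 256 lines up with the perceiver_cond_length_compression default added to the GPT constructor.

sr = 22050  # sample rate used for cloning audio in this diff

perceiver_mel = dict(n_fft=2048, hop_length=256, win_length=1024)  # gpt_use_perceiver_resampler=True
default_mel = dict(n_fft=4096, hop_length=1024, win_length=4096)   # gpt_use_perceiver_resampler=False

print(sr / perceiver_mel["hop_length"])  # ~86.1 conditioning frames per second
print(sr / default_mel["hop_length"])    # ~21.5 conditioning frames per second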