style fix

manmay-nakhashi 2023-04-22 17:49:47 +05:30
parent b892aa925a
commit 18c745ceef
22 changed files with 723 additions and 1605 deletions

View File

@@ -1,9 +1,8 @@
-import os
 import functools
 import math
+import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
@@ -11,6 +10,7 @@ from transformers import LogitsWarper
 from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias


 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
@@ -64,11 +64,11 @@ class QKVAttentionLegacy(nn.Module):
         ch = width // (3 * self.n_heads)
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         if rel_pos is not None:
-            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
+                bs * self.n_heads, weight.shape[-2], weight.shape[-1]
+            )
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         if mask is not None:
             # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
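A short aside on the "More stable with f16 than dividing afterwards" comment kept in this hunk: scaling q and k each by ch ** -0.25 before the einsum gives the same scores as dividing the q-k product by sqrt(ch), while keeping intermediate values smaller, which matters in float16. A minimal, self-contained sketch (illustrative only, not part of the commit):

import torch

bs, ch, length = 2, 64, 10
q = torch.randn(bs, ch, length)
k = torch.randn(bs, ch, length)

# Scale applied to both operands, as in QKVAttentionLegacy above.
scale = 1 / (ch ** 0.25)
scores_pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)

# Equivalent formulation that divides the raw scores afterwards.
scores_post_divided = torch.einsum("bct,bcs->bts", q, k) / (ch ** 0.5)

assert torch.allclose(scores_pre_scaled, scores_post_divided, atol=1e-5)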
@@ -112,7 +112,13 @@ class AttentionBlock(nn.Module):
         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
         if relative_pos_embeddings:
-            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+            self.relative_pos_embeddings = RelativePositionBias(
+                scale=(channels // self.num_heads) ** 0.5,
+                causal=False,
+                heads=num_heads,
+                num_buckets=32,
+                max_distance=64,
+            )
         else:
             self.relative_pos_embeddings = None
@@ -168,9 +174,7 @@ class Downsample(nn.Module):
         stride = factor
         if use_conv:
-            self.op = nn.Conv1d(
-                self.channels, self.out_channels, ksize, stride=stride, padding=pad
-            )
+            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
         else:
             assert self.channels == self.out_channels
             self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
@@ -221,17 +225,13 @@ class ResBlock(nn.Module):
             normalization(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
-            ),
+            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
         )
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = nn.Conv1d(
-                channels, self.out_channels, kernel_size, padding=padding
-            )
+            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
         else:
             self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
@@ -249,7 +249,8 @@ class ResBlock(nn.Module):
 class AudioMiniEncoder(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         spec_dim,
         embedding_dim,
         base_channels=128,
@@ -259,27 +260,27 @@ class AudioMiniEncoder(nn.Module):
         num_attn_heads=4,
         dropout=0,
         downsample_factor=2,
-        kernel_size=3):
+        kernel_size=3,
+    ):
         super().__init__()
-        self.init = nn.Sequential(
-            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
-        )
+        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
         ch = base_channels
         res = []
         for l in range(depth):
             for r in range(resnet_blocks):
                 res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
-            res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
+            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
-        self.final = nn.Sequential(
-            normalization(ch),
-            nn.SiLU(),
-            nn.Conv1d(ch, embedding_dim, 1)
-        )
+        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
         attn = []
         for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
+            attn.append(
+                AttentionBlock(
+                    embedding_dim,
+                    num_attn_heads,
+                )
+            )
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
@@ -291,15 +292,24 @@ class AudioMiniEncoder(nn.Module):
         return h[:, :, 0]


-DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../utils/assets/tortoise/mel_norms.pth')
+DEFAULT_MEL_NORM_FILE = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/mel_norms.pth"
+)


 class TorchMelSpectrogram(nn.Module):
-    def __init__(self, filter_length=1024, hop_length=256,
-                 win_length=1024, n_mel_channels=80,
-                 mel_fmin=0, mel_fmax=8000,
-                 sampling_rate=22050, normalize=False,
-                 mel_norm_file=DEFAULT_MEL_NORM_FILE):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=80,
+        mel_fmin=0,
+        mel_fmax=8000,
+        sampling_rate=22050,
+        normalize=False,
+        mel_norm_file=DEFAULT_MEL_NORM_FILE,
+    ):
         super().__init__()
         # These are the default tacotron values for the MEL spectrogram.
         self.filter_length = filter_length
@@ -309,7 +319,8 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_fmin = mel_fmin
         self.mel_fmax = mel_fmax
         self.sampling_rate = sampling_rate
-        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length,
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            n_fft=self.filter_length,
             hop_length=self.hop_length,
             win_length=self.win_length,
             power=2,
@@ -318,7 +329,8 @@ class TorchMelSpectrogram(nn.Module):
             f_min=self.mel_fmin,
             f_max=self.mel_fmax,
             n_mels=self.n_mel_channels,
-            norm="slaney")
+            norm="slaney",
+        )
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             self.mel_norms = torch.load(self.mel_norm_file)
@@ -326,7 +338,9 @@ class TorchMelSpectrogram(nn.Module):
             self.mel_norms = None

     def forward(self, inp):
-        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
+        if (
+            len(inp.shape) == 3
+        ):  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
         self.mel_stft = self.mel_stft.to(inp.device)
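For orientation, a rough usage sketch of the module reformatted above (illustrative only; it assumes this file is TTS/tts/layers/tortoise/arch_utils.py, which the imports elsewhere in this commit suggest, and it passes mel_norm_file=None so no asset file is needed):

import torch

from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram

mel_extractor = TorchMelSpectrogram(mel_norm_file=None)  # defaults: 22050 Hz, 80 mel channels
wav = torch.randn(2, 22050)  # a batch of two one-second mono waveforms, shaped (batch, samples)
mel = mel_extractor(wav)     # log-mel features shaped roughly (batch, n_mel_channels, frames)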
@@ -344,6 +358,7 @@ class CheckpointedLayer(nn.Module):
     Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
     checkpoint for all other args.
     """
+
     def __init__(self, wrap):
         super().__init__()
         self.wrap = wrap
@@ -360,6 +375,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
     Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
     to channels-last that XTransformer expects.
     """
+
     def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
         super().__init__()
         self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
@@ -374,10 +390,10 @@ class CheckpointedXTransformerEncoder(nn.Module):
     def forward(self, x, **kwargs):
         if self.needs_permute:
-            x = x.permute(0,2,1)
+            x = x.permute(0, 2, 1)
         h = self.transformer(x, **kwargs)
         if self.exit_permute:
-            h = h.permute(0,2,1)
+            h = h.permute(0, 2, 1)
         return h
@@ -392,9 +408,7 @@ class TypicalLogitsWarper(LogitsWarper):
         self.mass = mass
         self.min_tokens_to_keep = min_tokens_to_keep

-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # calculate entropy
         normalized = torch.nn.functional.log_softmax(scores, dim=-1)
         p = torch.exp(normalized)
@@ -409,15 +423,11 @@ class TypicalLogitsWarper(LogitsWarper):
         # Remove tokens with cumulative mass above the threshold
         last_ind = (cumulative_probs < self.mass).sum(dim=1)
         last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
-        )
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         if self.min_tokens_to_keep > 1:
             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
             sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores

View File

@@ -7,11 +7,10 @@ import numpy as np
 import torch
 import torchaudio
 from scipy.io.wavfile import read

 from TTS.utils.audio.torch_transforms import TorchSTFT

-BUILTIN_VOICES_DIR = os.path.join(
-    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices"
-)
+BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices")


 def load_wav_to_torch(full_path):
@@ -58,10 +57,7 @@ def read_audio_file(audiopath: str):
 def load_required_audio(audiopath: str):
     audio, lsr = read_audio_file(audiopath)

-    audios = [
-        torchaudio.functional.resample(audio, lsr, sampling_rate)
-        for sampling_rate in (22050, 24000)
-    ]
+    audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)]
     for audio in audios:
         check_audio(audio, audiopath)
@@ -83,9 +79,7 @@ TACOTRON_MEL_MIN = -11.512925148010254
 def denormalize_tacotron_mel(norm_mel):
-    return ((norm_mel + 1) / 2) * (
-        TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
-    ) + TACOTRON_MEL_MIN
+    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN


 def normalize_tacotron_mel(mel):
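For reference, the one-liner above maps the normalized range [-1, 1] back to the raw Tacotron log-mel range [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX]. Assuming normalize_tacotron_mel (whose body is not shown in this hunk) is the usual affine inverse, the pair is:

\text{denormalize}(m) = \frac{m + 1}{2}\,\bigl(\text{MEL}_{\max} - \text{MEL}_{\min}\bigr) + \text{MEL}_{\min},
\qquad
\text{normalize}(x) = 2\,\frac{x - \text{MEL}_{\min}}{\text{MEL}_{\max} - \text{MEL}_{\min}} - 1 .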
@@ -118,11 +112,7 @@ def get_voices(extra_voice_dirs: List[str] = []):
         for sub in subs:
             subj = os.path.join(d, sub)
             if os.path.isdir(subj):
-                voices[sub] = (
-                    list(glob(f"{subj}/*.wav"))
-                    + list(glob(f"{subj}/*.mp3"))
-                    + list(glob(f"{subj}/*.pth"))
-                )
+                voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth"))
     return voices
@@ -148,9 +138,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
     for voice in voices:
         if voice == "random":
             if len(voices) > 1:
-                print(
-                    "Cannot combine a random voice with a non-random voice. Just using a random voice."
-                )
+                print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
             return None, None
         clip, latent = load_voice(voice, extra_voice_dirs)
         if latent is None:
@@ -171,15 +159,18 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
             latents = (latents_0, latents_1)
             return None, latents


 def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
-    stft = TorchSTFT(n_fft=1024,
+    stft = TorchSTFT(
+        n_fft=1024,
         hop_length=256,
         win_length=1024,
         use_mel=True,
         n_mels=100,
         sample_rate=24000,
         mel_fmin=0,
-        mel_fmax=12000)
+        mel_fmax=12000,
+    )
     stft = stft.to(device)
     mel = stft(wav)
     mel = dynamic_range_compression(mel)

View File

@@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper


 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -98,9 +99,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         assert self.cached_mel_emb is not None
         assert inputs_embeds is None  # Not supported by this inference model.
         assert labels is None  # Training not supported by this inference model.
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # Create embedding
         mel_len = self.cached_mel_emb.shape[1]
@@ -109,9 +108,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
             text_emb = self.embeddings(text_inputs)
             text_emb = text_emb + self.text_pos_embedding(text_emb)
             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(
-                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
-                )
+                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
             else:  # this outcome only occurs once per loop in most cases
                 mel_emb = self.cached_mel_emb
             emb = torch.cat([mel_emb, text_emb], dim=1)
@@ -158,10 +155,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
        """
        return tuple(
-            tuple(
-                past_state.index_select(0, beam_idx.to(past_state.device))
-                for past_state in layer_past
-            )
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )
@@ -210,9 +204,7 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]


-def build_hf_gpt_transformer(
-    layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
-):
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     GPT-2 implemented by the HuggingFace library.
     """
@@ -230,9 +222,7 @@ def build_hf_gpt_transformer(
     )
     gpt = GPT2Model(gpt_config)
     # Override the built in positional embeddings
-    del (
-        gpt.wpe
-    )  # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
+    del gpt.wpe  # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
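A small self-contained sketch of the positional-embedding override this hunk touches (illustrative only, with a made-up tiny config): the learned wpe of a GPT2Model is replaced by a callable returning zeros, so position information comes only from the embeddings the caller builds externally.

import functools

import torch
from transformers import GPT2Config, GPT2Model


def null_position_embeddings(range, dim):
    # Same trick as above: act like an embedding layer but contribute nothing.
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


config = GPT2Config(n_embd=256, n_layer=2, n_head=4)  # small, illustrative config
gpt = GPT2Model(config)
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=config.n_embd)

emb = torch.randn(1, 10, 256)                   # externally built input embeddings
out = gpt(inputs_embeds=emb, return_dict=True)  # positional term is now all zeros
print(out.last_hidden_state.shape)              # torch.Size([1, 10, 256])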
@@ -251,21 +241,15 @@ class MelEncoder(nn.Module):
         self.channels = channels
         self.encoder = nn.Sequential(
             nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
-            nn.Sequential(
-                *[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
             nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
             nn.GroupNorm(channels // 16, channels // 2),
             nn.ReLU(),
-            nn.Sequential(
-                *[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
             nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
             nn.GroupNorm(channels // 8, channels),
             nn.ReLU(),
-            nn.Sequential(
-                *[ResBlock(channels) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
         )
         self.reduction = 4
@@ -317,9 +301,7 @@ class UnifiedVoice(nn.Module):
         super().__init__()

         self.number_text_tokens = number_text_tokens
-        self.start_text_token = (
-            number_text_tokens * types if start_text_token is None else start_text_token
-        )
+        self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
         self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
@@ -331,12 +313,8 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
-        self.conditioning_encoder = ConditioningEncoder(
-            80, model_dim, num_attn_heads=heads
-        )
-        self.text_embedding = nn.Embedding(
-            self.number_text_tokens * types + 1, model_dim
-        )
+        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
+        self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -356,12 +334,8 @@ class UnifiedVoice(nn.Module):
             checkpointing,
         )
         if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(
-                torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
-            )
-            self.text_solo_embedding = nn.Parameter(
-                torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
-            )
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
         else:
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0
@@ -414,9 +388,7 @@ class UnifiedVoice(nn.Module):
         preformatting to create a working TTS model.
         """
         # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
-        mel_lengths = torch.div(
-            wav_lengths, self.mel_length_compression, rounding_mode="trunc"
-        )
+        mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc")
         for b in range(len(mel_lengths)):
             actual_end = (
                 mel_lengths[b] + 1
@@ -436,31 +408,22 @@ class UnifiedVoice(nn.Module):
         return_latent=False,
     ):
         if second_inputs is not None:
-            emb = torch.cat(
-                [speech_conditioning_inputs, first_inputs, second_inputs], dim=1
-            )
+            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
         else:
             emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)

-        gpt_out = self.gpt(
-            inputs_embeds=emb, return_dict=True, output_attentions=get_attns
-        )
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
         if get_attns:
             return gpt_out.attentions

-        enc = gpt_out.last_hidden_state[
-            :, 1:
-        ]  # The first logit is tied to the speech_conditioning_input
+        enc = gpt_out.last_hidden_state[:, 1:]  # The first logit is tied to the speech_conditioning_input
         enc = self.final_norm(enc)

         if return_latent:
             return (
                 enc[
                     :,
-                    speech_conditioning_inputs.shape[
-                        1
-                    ] : speech_conditioning_inputs.shape[1]
-                    + first_inputs.shape[1],
+                    speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1],
                 ],
                 enc[:, -second_inputs.shape[1] :],
             )
@@ -539,9 +502,7 @@ class UnifiedVoice(nn.Module):
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(
             text_inputs, self.start_text_token, self.stop_text_token
         )
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
-            text_inputs
-        )
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
             mel_codes, self.start_mel_token, self.stop_mel_token
         )
@@ -596,15 +557,13 @@ class UnifiedVoice(nn.Module):
         max_generate_length=None,
         typical_sampling=False,
         typical_mass=0.9,
-        **hf_generate_kwargs
+        **hf_generate_kwargs,
     ):
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(
             text_inputs, self.start_text_token, self.stop_text_token
         )
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
-            text_inputs
-        )
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)

         conds = speech_conditioning_latent.unsqueeze(1)
         emb = torch.cat([conds, text_emb], dim=1)
@@ -628,20 +587,14 @@ class UnifiedVoice(nn.Module):
                 num_return_sequences % input_tokens.shape[0] == 0
             ), "The number of return sequences must be divisible by the number of input sequences"
             fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
-            input_tokens = input_tokens.repeat(
-                num_return_sequences // input_tokens.shape[0], 1
-            )
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
         inputs = torch.cat([fake_inputs, input_tokens], dim=1)

         logits_processor = (
-            LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
-            if typical_sampling
-            else LogitsProcessorList()
+            LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         )  # TODO disable this
         max_length = (
-            trunc_index + self.max_mel_tokens - 1
-            if max_generate_length is None
-            else trunc_index + max_generate_length
+            trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
         )
         gen = self.inference_model.generate(
             inputs,
@@ -651,7 +604,7 @@ class UnifiedVoice(nn.Module):
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
-            **hf_generate_kwargs
+            **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]

View File

@@ -1,13 +1,7 @@
 import torch
 import torch.nn as nn
-from TTS.tts.layers.tortoise.arch_utils import (
-    AttentionBlock,
-    Downsample,
-    Upsample,
-    normalization,
-    zero_module,
-)
+from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module


 class ResBlock(nn.Module):
@@ -54,19 +48,13 @@ class ResBlock(nn.Module):
             normalization(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                nn.Conv1d(
-                    self.out_channels, self.out_channels, kernel_size, padding=padding
-                )
-            ),
+            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
         )
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = nn.Conv1d(
-                dims, channels, self.out_channels, kernel_size, padding=padding
-            )
+            self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding)
         else:
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
@@ -104,24 +92,14 @@ class AudioMiniEncoder(nn.Module):
         self.layers = depth
         for l in range(depth):
             for r in range(resnet_blocks):
-                res.append(
-                    ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
-                )
-            res.append(
-                Downsample(
-                    ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
-                )
-            )
+                res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
+            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
-        self.final = nn.Sequential(
-            normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
-        )
+        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
         attn = []
         for a in range(attn_blocks):
-            attn.append(
-                AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
-            )
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim

View File

@@ -12,9 +12,10 @@ def exists(val):
     return val is not None


-def masked_mean(t, mask, dim = 1):
-    t = t.masked_fill(~mask[:, :, None], 0.)
-    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
+def masked_mean(t, mask, dim=1):
+    t = t.masked_fill(~mask[:, :, None], 0.0)
+    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


 class CLVP(nn.Module):
     """
@@ -59,13 +60,14 @@ class CLVP(nn.Module):
                     dim=dim_text,
                     depth=text_enc_depth,
                     heads=text_heads,
-                    ff_dropout=.1,
+                    ff_dropout=0.1,
                     ff_mult=2,
-                    attn_dropout=.1,
+                    attn_dropout=0.1,
                     use_rmsnorm=True,
                     ff_glu=True,
                     rotary_pos_emb=True,
-                ))
+                ),
+            )
             self.speech_transformer = CheckpointedXTransformerEncoder(
                 needs_permute=False,
                 exit_permute=False,
@@ -74,20 +76,23 @@ class CLVP(nn.Module):
                     dim=dim_speech,
                     depth=speech_enc_depth,
                     heads=speech_heads,
-                    ff_dropout=.1,
+                    ff_dropout=0.1,
                     ff_mult=2,
-                    attn_dropout=.1,
+                    attn_dropout=0.1,
                     use_rmsnorm=True,
                     ff_glu=True,
                     rotary_pos_emb=True,
-                ))
+                ),
+            )
         else:
-            self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
-                                                heads=text_heads)
-            self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
-                                                  depth=speech_enc_depth, heads=speech_heads)
+            self.text_transformer = Transformer(
+                causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads
+            )
+            self.speech_transformer = Transformer(
+                causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads
+            )

-        self.temperature = nn.Parameter(torch.tensor(1.))
+        self.temperature = nn.Parameter(torch.tensor(1.0))
         self.text_mask_percentage = text_mask_percentage
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
@@ -96,12 +101,7 @@ class CLVP(nn.Module):
         self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
         self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)

-    def forward(
-        self,
-        text,
-        speech_tokens,
-        return_loss=False
-    ):
+    def forward(self, text, speech_tokens, return_loss=False):
         b, device = text.shape[0], text.device
         if self.training:
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
@@ -131,25 +131,29 @@ class CLVP(nn.Module):
         temp = self.temperature.exp()

         if not return_loss:
-            sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
+            sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
             return sim

-        sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
+        sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
         labels = torch.arange(b, device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
         return loss


-if __name__ == '__main__':
-    clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
-    clip(torch.randint(0,256,(2,120)),
-         torch.tensor([50,100]),
-         torch.randint(0,8192,(2,250)),
-         torch.tensor([101,102]),
-         return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
-                   torch.tensor([50,100]),
-                   torch.randint(0,8192,(2,250)),
-                   torch.tensor([101,102]),
-                   return_loss=False)
+if __name__ == "__main__":
+    clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
+    clip(
+        torch.randint(0, 256, (2, 120)),
+        torch.tensor([50, 100]),
+        torch.randint(0, 8192, (2, 250)),
+        torch.tensor([101, 102]),
+        return_loss=True,
+    )
+    nonloss = clip(
+        torch.randint(0, 256, (2, 120)),
+        torch.tensor([50, 100]),
+        torch.randint(0, 8192, (2, 250)),
+        torch.tensor([101, 102]),
+        return_loss=False,
+    )
     print(nonloss.shape)
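The loss block reformatted above is the standard symmetric contrastive (CLIP-style) objective. Written out, with text latents t_i, speech latents s_j, learned temperature tau and batch labels y_i = i (notation mine, not part of the commit):

S_{ij} = e^{\tau}\,\langle t_i, s_j \rangle,
\qquad
\mathcal{L} = \tfrac{1}{2}\bigl[\mathrm{CE}(S, y) + \mathrm{CE}(S^{\top}, y)\bigr].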

View File

@@ -17,16 +17,7 @@ def masked_mean(t, mask):
 class CollapsingTransformer(nn.Module):
-    def __init__(
-        self,
-        model_dim,
-        output_dims,
-        heads,
-        dropout,
-        depth,
-        mask_percentage=0,
-        **encoder_kwargs
-    ):
+    def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
         super().__init__()
         self.transformer = ContinuousTransformerWrapper(
             max_seq_len=-1,
@@ -105,9 +96,7 @@ class CVVP(nn.Module):
         self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)

         if mel_codes is None:
-            self.speech_emb = nn.Conv1d(
-                mel_channels, model_dim, kernel_size=5, padding=2
-            )
+            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(
@@ -135,9 +124,7 @@ class CVVP(nn.Module):
         enc_speech = self.speech_transformer(speech_emb)
         speech_latents = self.to_speech_latent(enc_speech)

-        cond_latents, speech_latents = map(
-            lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
-        )
+        cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
         temp = self.temperature.exp()

         if not return_loss:

View File

@@ -13,8 +13,8 @@ import math
 import numpy as np
 import torch
 import torch as th
-from tqdm import tqdm
 from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
+from tqdm import tqdm

 from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
@@ -38,18 +38,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
     # Force variances to be Tensors. Broadcasting helps convert scalars to
     # Tensors, but it does not work for th.exp().
-    logvar1, logvar2 = [
-        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
-        for x in (logvar1, logvar2)
-    ]
+    logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]

-    return 0.5 * (
-        -1.0
-        + logvar2
-        - logvar1
-        + th.exp(logvar1 - logvar2)
-        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
-    )
+    return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))


 def approx_standard_normal_cdf(x):
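Written out, the single return line above is the closed-form KL divergence between two diagonal Gaussians, in terms of the function's mean and log-variance arguments (with l = logvar):

\mathrm{KL}\bigl(\mathcal{N}(\mu_1, e^{\ell_1}) \,\|\, \mathcal{N}(\mu_2, e^{\ell_2})\bigr)
= \tfrac{1}{2}\Bigl(-1 + \ell_2 - \ell_1 + e^{\ell_1 - \ell_2} + (\mu_1 - \mu_2)^2\, e^{-\ell_2}\Bigr).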
@@ -112,9 +103,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
         scale = 1000 / num_diffusion_timesteps
         beta_start = scale * 0.0001
         beta_end = scale * 0.02
-        return np.linspace(
-            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
-        )
+        return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
     elif schedule_name == "cosine":
         return betas_for_alpha_bar(
             num_diffusion_timesteps,
@@ -149,9 +138,9 @@ class ModelMeanType(enum.Enum):
     Which type of output the model predicts.
     """

-    PREVIOUS_X = 'previous_x'  # the model predicts x_{t-1}
-    START_X = 'start_x'  # the model predicts x_0
-    EPSILON = 'epsilon'  # the model predicts epsilon
+    PREVIOUS_X = "previous_x"  # the model predicts x_{t-1}
+    START_X = "start_x"  # the model predicts x_0
+    EPSILON = "epsilon"  # the model predicts epsilon


 class ModelVarType(enum.Enum):
@@ -162,17 +151,17 @@ class ModelVarType(enum.Enum):
     values between FIXED_SMALL and FIXED_LARGE, making its job easier.
     """

-    LEARNED = 'learned'
-    FIXED_SMALL = 'fixed_small'
-    FIXED_LARGE = 'fixed_large'
-    LEARNED_RANGE = 'learned_range'
+    LEARNED = "learned"
+    FIXED_SMALL = "fixed_small"
+    FIXED_LARGE = "fixed_large"
+    LEARNED_RANGE = "learned_range"


 class LossType(enum.Enum):
-    MSE = 'mse'  # use raw MSE loss (and KL when learning variances)
-    RESCALED_MSE = 'rescaled_mse'  # use raw MSE loss (with RESCALED_KL when learning variances)
-    KL = 'kl'  # use the variational lower-bound
-    RESCALED_KL = 'rescaled_kl'  # like KL, but rescale to estimate the full VLB
+    MSE = "mse"  # use raw MSE loss (and KL when learning variances)
+    RESCALED_MSE = "rescaled_mse"  # use raw MSE loss (with RESCALED_KL when learning variances)
+    KL = "kl"  # use the variational lower-bound
+    RESCALED_KL = "rescaled_kl"  # like KL, but rescale to estimate the full VLB

     def is_vb(self):
         return self == LossType.KL or self == LossType.RESCALED_KL
@@ -239,22 +228,12 @@ class GaussianDiffusion:
         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

         # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = (
-            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
         # log calculation clipped because the posterior variance is 0 at the
         # beginning of the diffusion chain.
-        self.posterior_log_variance_clipped = np.log(
-            np.append(self.posterior_variance[1], self.posterior_variance[1:])
-        )
-        self.posterior_mean_coef1 = (
-            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        )
-        self.posterior_mean_coef2 = (
-            (1.0 - self.alphas_cumprod_prev)
-            * np.sqrt(alphas)
-            / (1.0 - self.alphas_cumprod)
-        )
+        self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
+        self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)

     def q_mean_variance(self, x_start, t):
         """
@@ -264,13 +243,9 @@ class GaussianDiffusion:
         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
         """
-        mean = (
-            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-        )
+        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
-        log_variance = _extract_into_tensor(
-            self.log_one_minus_alphas_cumprod, t, x_start.shape
-        )
+        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
         return mean, variance, log_variance

     def q_sample(self, x_start, t, noise=None):
@@ -289,8 +264,7 @@ class GaussianDiffusion:
         assert noise.shape == x_start.shape
         return (
             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
-            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
-            * noise
+            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

     def q_posterior_mean_variance(self, x_start, x_t, t):
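For reference, q_sample as reformatted above is the usual closed-form DDPM forward sample (x_0 is x_start; the coefficients are the precomputed sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod):

x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\varepsilon,
\qquad \varepsilon \sim \mathcal{N}(0, I).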
@@ -306,9 +280,7 @@ class GaussianDiffusion:
             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
         )
         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(
-            self.posterior_log_variance_clipped, t, x_t.shape
-        )
+        posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
         assert (
             posterior_mean.shape[0]
             == posterior_variance.shape[0]
@@ -317,9 +289,7 @@ class GaussianDiffusion:
         )
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def p_mean_variance(
-        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
-    ):
+    def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
         """
         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
         the initial x, x_0.
@@ -358,9 +328,7 @@ class GaussianDiffusion:
                 model_log_variance = model_var_values
                 model_variance = th.exp(model_log_variance)
             else:
-                min_log = _extract_into_tensor(
-                    self.posterior_log_variance_clipped, t, x.shape
-                )
+                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
                 max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                 # The model_var_values is [-1, 1] for [min_var, max_var].
                 frac = (model_var_values + 1) / 2
@@ -398,26 +366,18 @@ class GaussianDiffusion:
             return x

         if self.model_mean_type == ModelMeanType.PREVIOUS_X:
-            pred_xstart = process_xstart(
-                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
-            )
+            pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
             model_mean = model_output
         elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
             if self.model_mean_type == ModelMeanType.START_X:
                 pred_xstart = process_xstart(model_output)
             else:
-                pred_xstart = process_xstart(
-                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
-                )
-            model_mean, _, _ = self.q_posterior_mean_variance(
-                x_start=pred_xstart, x_t=x, t=t
-            )
+                pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
+            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
         else:
             raise NotImplementedError(self.model_mean_type)

-        assert (
-            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        )
+        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
         return {
             "mean": model_mean,
             "variance": model_variance,
@@ -436,16 +396,12 @@ class GaussianDiffusion:
         assert x_t.shape == xprev.shape
         return (  # (xprev - coef2*x_t) / coef1
             _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
-            - _extract_into_tensor(
-                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
-            )
-            * x_t
+            - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t
         )

     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
         return (
-            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - pred_xstart
+            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

     def _scale_timesteps(self, t):
@@ -463,9 +419,7 @@ class GaussianDiffusion:
         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
         """
         gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
-        new_mean = (
-            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
-        )
+        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
         return new_mean

     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@@ -481,16 +435,13 @@ class GaussianDiffusion:
         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
-        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
-            x, self._scale_timesteps(t), **model_kwargs
-        )
+        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs)

         out = p_mean_var.copy()
         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
-        out["mean"], _, _ = self.q_posterior_mean_variance(
-            x_start=out["pred_xstart"], x_t=x, t=t
-        )
+        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
         return out

     def k_diffusion_sample_loop(
         self,
         k_sampler,
@@ -512,9 +463,7 @@ class GaussianDiffusion:
         def model_split(*args, **kwargs):
             model_output = model(*args, **kwargs)
-            model_epsilon, model_var = th.split(
-                model_output, model_output.shape[1] // 2, dim=1
-            )
+            model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1)
             return model_epsilon, model_var

         #
@@ -523,9 +472,7 @@ class GaussianDiffusion:
        print(th.tensor(self.betas))
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas))
        """
-        noise_schedule = NoiseScheduleVP(
-            schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4
-        )
+        noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4)

         def model_fn_prewrap(x, t, *args, **kwargs):
             """
@@ -584,11 +531,10 @@ class GaussianDiffusion:
             if self.conditioning_free is not True:
                 raise RuntimeError("cond_free must be true")
             with tqdm(total=self.num_timesteps) as pbar:
-                return self.k_diffusion_sample_loop(
-                    K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs
-                )
+                return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
         else:
             raise RuntimeError("sampler not impl")

     def p_sample(
         self,
         model,
@@ -625,13 +571,9 @@ class GaussianDiffusion:
             model_kwargs=model_kwargs,
         )
         noise = th.randn_like(x)
-        nonzero_mask = (
-            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
-        )  # no noise when t == 0
+        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
         if cond_fn is not None:
-            out["mean"] = self.condition_mean(
-                cond_fn, out, x, t, model_kwargs=model_kwargs
-            )
+            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@@ -758,20 +700,11 @@ class GaussianDiffusion:
         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
-        sigma = (
-            eta
-            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
-            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
-        )
+        sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
         # Equation 12.
         noise = th.randn_like(x)
-        mean_pred = (
-            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
-            + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
-        )
-        nonzero_mask = (
-            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
-        )  # no noise when t == 0
+        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))  # no noise when t == 0
         sample = mean_pred + nonzero_mask * sigma * noise
         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -800,16 +733,12 @@ class GaussianDiffusion:
# Usually our model outputs epsilon, but we re-derive it # Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction. # in case we used x_start or x_prev prediction.
eps = ( eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed # Equation 12. reversed
mean_pred = ( mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
@ -897,9 +826,7 @@ class GaussianDiffusion:
yield out yield out
img = out["sample"] img = out["sample"]
def _vb_terms_bpd( def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
""" """
Get a term for the variational lower-bound. Get a term for the variational lower-bound.
@ -910,15 +837,9 @@ class GaussianDiffusion:
- 'output': a shape [N] tensor of NLLs or KLs. - 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions. - 'pred_xstart': the x_0 predictions.
""" """
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
x_start=x_start, x_t=x_t, t=t out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
) kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0) kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood( decoder_nll = -discretized_gaussian_log_likelihood(
@ -969,7 +890,7 @@ class GaussianDiffusion:
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
if isinstance(model_outputs, tuple): if isinstance(model_outputs, tuple):
model_output = model_outputs[0] model_output = model_outputs[0]
terms['extra_outputs'] = model_outputs[1:] terms["extra_outputs"] = model_outputs[1:]
else: else:
model_output = model_outputs model_output = model_outputs
@ -996,9 +917,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0 terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X: if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance( target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported. x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X: elif self.model_mean_type == ModelMeanType.START_X:
target = x_start target = x_start
@ -1020,7 +939,9 @@ class GaussianDiffusion:
return terms return terms
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None): def autoregressive_training_losses(
self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None
):
""" """
Compute training losses for a single timestep. Compute training losses for a single timestep.
@ -1068,9 +989,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0 terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X: if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance( target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported. x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X: elif self.model_mean_type == ModelMeanType.START_X:
target = x_start target = x_start
@ -1105,9 +1024,7 @@ class GaussianDiffusion:
batch_size = x_start.shape[0] batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl( kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
return mean_flat(kl_prior) / np.log(2.0) return mean_flat(kl_prior) / np.log(2.0)
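The prior term above is the KL between q(x_T | x_0) and a standard normal, reported in bits per dimension. A small sketch of the closed form this relies on (standard diagonal-Gaussian KL; names are illustrative, not the repo's normal_kl):

import math
import torch

def gaussian_kl_bits(mean1, logvar1, mean2, logvar2):
    # KL(N(mean1, e^logvar1) || N(mean2, e^logvar2)) per element, in nats.
    kl = 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
                + (mean1 - mean2) ** 2 * torch.exp(-logvar2))
    # Average over non-batch dims (inputs assumed [N, ...]) and convert nats -> bits.
    return kl.flatten(1).mean(dim=1) / math.log(2.0)

For the prior term, mean2 = 0 and logvar2 = 0, i.e. the reference distribution is N(0, I).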
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
@ -1183,9 +1100,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
scale = 1000 / num_diffusion_timesteps scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
return np.linspace( return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == "cosine": elif schedule_name == "cosine":
return betas_for_alpha_bar( return betas_for_alpha_bar(
num_diffusion_timesteps, num_diffusion_timesteps,
@ -1219,19 +1134,13 @@ class SpacedDiffusion(GaussianDiffusion):
kwargs["betas"] = np.array(new_betas) kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs) super().__init__(**kwargs)
def p_mean_variance( def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses( def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs) return super().training_losses(self._wrap_model(model), *args, **kwargs)
def autoregressive_training_losses( def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs): def condition_mean(self, cond_fn, *args, **kwargs):
@ -1244,9 +1153,7 @@ class SpacedDiffusion(GaussianDiffusion):
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
return model return model
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
return mod( return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
def _scale_timesteps(self, t): def _scale_timesteps(self, t):
# Scaling is done by the wrapped model. # Scaling is done by the wrapped model.
@ -1281,9 +1188,7 @@ def space_timesteps(num_timesteps, section_counts):
for i in range(1, num_timesteps): for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count: if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i)) return set(range(0, num_timesteps, i))
raise ValueError( raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")] section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts) size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts) extra = num_timesteps % len(section_counts)
@ -1292,9 +1197,7 @@ def space_timesteps(num_timesteps, section_counts):
for i, section_count in enumerate(section_counts): for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0) size = size_per + (1 if i < extra else 0)
if size < section_count: if size < section_count:
raise ValueError( raise ValueError(f"cannot divide section of {size} steps into {section_count}")
f"cannot divide section of {size} steps into {section_count}"
)
if section_count <= 1: if section_count <= 1:
frac_stride = 1 frac_stride = 1
else: else:
@ -1315,6 +1218,7 @@ class _WrappedModel:
self.timestep_map = timestep_map self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs): def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts] new_ts = map_tensor[ts]
@ -1323,6 +1227,7 @@ class _WrappedModel:
model_output = self.model(x, new_ts, **kwargs) model_output = self.model(x, new_ts, **kwargs)
return model_output return model_output
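The wrapper only remaps indices in the spaced schedule back onto the original schedule before calling the inner model; in isolation (illustrative values):

import torch

timestep_map = [0, 250, 500, 750, 999]   # spaced subset of an assumed 1000-step schedule
ts = torch.tensor([0, 3, 4])             # indices into the spaced schedule
map_tensor = torch.tensor(timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]                  # tensor([  0, 750, 999]): original-schedule timesteps

If rescale_timesteps is set, new_ts is additionally scaled by 1000.0 / original_num_steps, as in the autoregressive wrapper below.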
class _WrappedAutoregressiveModel: class _WrappedAutoregressiveModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model self.model = model
@ -1337,6 +1242,7 @@ class _WrappedAutoregressiveModel:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps) new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, x0, new_ts, **kwargs) return self.model(x, x0, new_ts, **kwargs)
def _extract_into_tensor(arr, timesteps, broadcast_shape): def _extract_into_tensor(arr, timesteps, broadcast_shape):
""" """
Extract values from a 1-D numpy array for a batch of indices. Extract values from a 1-D numpy array for a batch of indices.
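A helper like this usually gathers per-timestep coefficients for a batch and broadcasts them over the sample shape; a sketch of the typical implementation (illustrative, not necessarily the exact body):

import torch

def extract_into_tensor(arr, timesteps, broadcast_shape):
    # arr: 1-D numpy array of per-timestep coefficients; timesteps: LongTensor of shape [N].
    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    # Append singleton dims until the result broadcasts against e.g. an [N, C, T] batch.
    while res.dim() < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)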
View File
@ -29,11 +29,9 @@ def timestep_embedding(timesteps, dim, max_period=10000):
:return: an [N x dim] Tensor of positional embeddings. :return: an [N x dim] Tensor of positional embeddings.
""" """
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-math.log(max_period) device=timesteps.device
* torch.arange(start=0, end=half, dtype=torch.float32) )
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
@ -98,17 +96,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels), normalization(self.out_channels),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
nn.Conv1d( nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
self.out_channels, self.out_channels, kernel_size, padding=padding
),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
else: else:
self.skip_connection = nn.Conv1d( self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
channels, self.out_channels, eff_kernel, padding=eff_padding
)
def forward(self, x, emb): def forward(self, x, emb):
h = self.in_layers(x) h = self.in_layers(x)
@ -137,9 +131,7 @@ class DiffusionLayer(TimestepBlock):
dims=1, dims=1,
use_scale_shift_norm=True, use_scale_shift_norm=True,
) )
self.attn = AttentionBlock( self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
model_channels, num_heads, relative_pos_embeddings=True
)
def forward(self, x, time_emb): def forward(self, x, time_emb):
y = self.resblk(x, time_emb) y = self.resblk(x, time_emb)
@ -239,16 +231,11 @@ class DiffusionTts(nn.Module):
DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads),
) )
self.integrating_conv = nn.Conv1d( self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
model_channels * 2, model_channels, kernel_size=1
)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
DiffusionLayer(model_channels, dropout, num_heads)
for _ in range(num_layers)
]
+ [ + [
ResBlock( ResBlock(
model_channels, model_channels,
@ -275,9 +262,7 @@ class DiffusionTts(nn.Module):
+ list(self.code_converter.parameters()) + list(self.code_converter.parameters())
+ list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters())
+ list(self.latent_conditioner.parameters()), + list(self.latent_conditioner.parameters()),
"timestep_integrator": list( "timestep_integrator": list(self.conditioning_timestep_integrator.parameters())
self.conditioning_timestep_integrator.parameters()
)
+ list(self.integrating_conv.parameters()), + list(self.integrating_conv.parameters()),
"time_embed": list(self.time_embed.parameters()), "time_embed": list(self.time_embed.parameters()),
} }
@ -285,9 +270,7 @@ class DiffusionTts(nn.Module):
def get_conditioning(self, conditioning_input): def get_conditioning(self, conditioning_input):
speech_conditioning_input = ( speech_conditioning_input = (
conditioning_input.unsqueeze(1) conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
if len(conditioning_input.shape) == 3
else conditioning_input
) )
conds = [] conds = []
for j in range(speech_conditioning_input.shape[1]): for j in range(speech_conditioning_input.shape[1]):
@ -313,29 +296,20 @@ class DiffusionTts(nn.Module):
else: else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb) code_emb = self.code_converter(code_emb)
code_emb = self.code_norm(code_emb) * ( code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
1 + cond_scale.unsqueeze(-1)
) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros( unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
(code_emb.shape[0], 1, 1), device=code_emb.device
)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0: if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = ( unconditioned_batches = (
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
< self.unconditioned_percentage
) )
code_emb = torch.where( code_emb = torch.where(
unconditioned_batches, unconditioned_batches,
self.unconditioned_embedding.repeat( self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
aligned_conditioning.shape[0], 1, 1
),
code_emb, code_emb,
) )
expanded_code_emb = F.interpolate( expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest")
code_emb, size=expected_seq_len, mode="nearest"
)
if not return_code_pred: if not return_code_pred:
return expanded_code_emb return expanded_code_emb
@ -376,10 +350,7 @@ class DiffusionTts(nn.Module):
unused_params = [] unused_params = []
if conditioning_free: if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend( unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
list(self.code_converter.parameters())
+ list(self.code_embedding.parameters())
)
unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.extend(list(self.latent_conditioner.parameters()))
else: else:
if precomputed_aligned_embeddings is not None: if precomputed_aligned_embeddings is not None:
@ -390,8 +361,7 @@ class DiffusionTts(nn.Module):
) )
if is_latent(aligned_conditioning): if is_latent(aligned_conditioning):
unused_params.extend( unused_params.extend(
list(self.code_converter.parameters()) list(self.code_converter.parameters()) + list(self.code_embedding.parameters())
+ list(self.code_embedding.parameters())
) )
else: else:
unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.extend(list(self.latent_conditioner.parameters()))
View File
@ -1,6 +1,8 @@
import math import math
import torch import torch
class NoiseScheduleVP: class NoiseScheduleVP:
def __init__( def __init__(
self, self,
@ -107,11 +109,7 @@ class NoiseScheduleVP:
log_alphas = 0.5 * torch.log(alphas_cumprod) log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas) self.total_N = len(log_alphas)
self.T = 1.0 self.T = 1.0
self.t_array = ( self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
torch.linspace(0.0, 1.0, self.total_N + 1)[1:]
.reshape((1, -1))
.to(dtype=dtype)
)
self.log_alpha_array = log_alphas.reshape( self.log_alpha_array = log_alphas.reshape(
( (
1, 1,
@ -131,9 +129,7 @@ class NoiseScheduleVP:
/ math.pi / math.pi
- self.cosine_s - self.cosine_s
) )
self.cosine_log_alpha_0 = math.log( self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
)
self.schedule = schedule self.schedule = schedule
if schedule == "cosine": if schedule == "cosine":
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
@ -157,11 +153,7 @@ class NoiseScheduleVP:
elif self.schedule == "cosine": elif self.schedule == "cosine":
def log_alpha_fn(s): def log_alpha_fn(s):
return torch.log( return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
torch.cos(
(s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
)
)
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t return log_alpha_t
@ -191,17 +183,11 @@ class NoiseScheduleVP:
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
""" """
if self.schedule == "linear": if self.schedule == "linear":
tmp = ( tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
2.0
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
Delta = self.beta_0**2 + tmp Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete": elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp( log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
)
t = interpolate_fn( t = interpolate_fn(
log_alpha.reshape((-1, 1)), log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
@ -345,14 +331,10 @@ def model_wrapper(
if model_type == "noise": if model_type == "noise":
return output return output
elif model_type == "x_start": elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha( alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
t_continuous
), noise_schedule.marginal_std(t_continuous)
return (x - alpha_t * output) / sigma_t return (x - alpha_t * output) / sigma_t
elif model_type == "v": elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha( alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
t_continuous
), noise_schedule.marginal_std(t_continuous)
return alpha_t * output + sigma_t * x return alpha_t * output + sigma_t * x
elif model_type == "score": elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous) sigma_t = noise_schedule.marginal_std(t_continuous)
@ -482,9 +464,7 @@ class DPM_Solver:
p = self.dynamic_thresholding_ratio p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims( s = expand_dims(
torch.maximum( torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)),
s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
),
dims, dims,
) )
x0 = torch.clamp(x0, -s, s) / s x0 = torch.clamp(x0, -s, s) / s
@ -501,9 +481,7 @@ class DPM_Solver:
Return the data prediction model (with corrector). Return the data prediction model (with corrector).
""" """
noise = self.noise_prediction_fn(x, t) noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha( alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
t
), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None: if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t) x0 = self.correcting_x0_fn(x0, t)
@ -536,30 +514,20 @@ class DPM_Solver:
if skip_type == "logSNR": if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace( logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps) return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform": elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device) return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic": elif skip_type == "time_quadratic":
t_order = 2 t_order = 2
t = ( t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
.pow(t_order)
.to(device)
)
return t return t
else: else:
raise ValueError( raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format( "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
skip_type
)
) )
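The three skip types only differ in how [t_0, t_T] is discretised; a small standalone illustration (made-up endpoints):

import torch

t_T, t_0, N = 1.0, 1e-3, 10
uniform = torch.linspace(t_T, t_0, N + 1)                       # "time_uniform"
quadratic = torch.linspace(t_T ** 0.5, t_0 ** 0.5, N + 1) ** 2  # "time_quadratic" (t_order = 2)
# "logSNR" instead spaces the steps uniformly in lambda_t = log(alpha_t / sigma_t),
# which needs the noise schedule's marginal_lambda / inverse_lambda as above.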
def get_orders_and_timesteps_for_singlestep_solver( def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
self, steps, order, skip_type, t_T, t_0, device
):
""" """
Get the order of each step for sampling by the singlestep DPM-Solver. Get the order of each step for sampling by the singlestep DPM-Solver.
@ -664,9 +632,7 @@ class DPM_Solver:
dims = x.dim() dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
s
), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -716,11 +682,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None: if r1 is None:
r1 = 0.5 r1 = 0.5
ns = self.noise_schedule ns = self.noise_schedule
@ -766,10 +728,7 @@ class DPM_Solver:
if model_s is None: if model_s is None:
model_s = self.model_fn(x, s) model_s = self.model_fn(x, s)
x_s1 = ( x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
torch.exp(log_alpha_s1 - log_alpha_s) * x
- (sigma_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1) model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver": if solver_type == "dpmsolver":
x_t = ( x_t = (
@ -820,11 +779,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None: if r1 is None:
r1 = 1.0 / 3.0 r1 = 1.0 / 3.0
if r2 is None: if r2 is None:
@ -901,9 +856,7 @@ class DPM_Solver:
if model_s is None: if model_s is None:
model_s = self.model_fn(x, s) model_s = self.model_fn(x, s)
if model_s1 is None: if model_s1 is None:
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - ( x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
sigma_s1 * phi_11
) * model_s
model_s1 = self.model_fn(x_s1, s1) model_s1 = self.model_fn(x_s1, s1)
x_s2 = ( x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x (torch.exp(log_alpha_s2 - log_alpha_s)) * x
@ -934,9 +887,7 @@ class DPM_Solver:
else: else:
return x_t return x_t
def multistep_dpm_solver_second_update( def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
""" """
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
@ -951,11 +902,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
ns = self.noise_schedule ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
@ -964,9 +911,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t), ns.marginal_lambda(t),
) )
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -977,11 +922,7 @@ class DPM_Solver:
if self.algorithm_type == "dpmsolver++": if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h) phi_1 = torch.expm1(-h)
if solver_type == "dpmsolver": if solver_type == "dpmsolver":
x_t = ( x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
- 0.5 * (alpha_t * phi_1) * D1_0
)
elif solver_type == "taylor": elif solver_type == "taylor":
x_t = ( x_t = (
(sigma_t / sigma_prev_0) * x (sigma_t / sigma_prev_0) * x
@ -1004,9 +945,7 @@ class DPM_Solver:
) )
return x_t return x_t
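Stripped of bookkeeping, the dpmsolver++ branch above is a two-term linear-multistep update in half-logSNR space; a compact sketch with scalar-style names (illustrative only):

import torch

def second_order_update(x, m_prev_0, m_prev_1, lam_prev_1, lam_prev_0, lam_t,
                        sigma_prev_0, sigma_t, alpha_t):
    h, h_0 = lam_t - lam_prev_0, lam_prev_0 - lam_prev_1
    D1_0 = (m_prev_0 - m_prev_1) / (h_0 / h)   # finite difference of the last two model outputs
    phi_1 = torch.expm1(torch.as_tensor(-h))
    return (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * m_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0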
def multistep_dpm_solver_third_update( def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
""" """
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
@ -1029,9 +968,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t), ns.marginal_lambda(t),
) )
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -1093,9 +1030,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if order == 1: if order == 1:
return self.dpm_solver_first_update( return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
x, s, t, return_intermediate=return_intermediate
)
elif order == 2: elif order == 2:
return self.singlestep_dpm_solver_second_update( return self.singlestep_dpm_solver_second_update(
x, x,
@ -1118,9 +1053,7 @@ class DPM_Solver:
else: else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update( def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"
):
""" """
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
@ -1136,17 +1069,11 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if order == 1: if order == 1:
return self.dpm_solver_first_update( return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
)
elif order == 2: elif order == 2:
return self.multistep_dpm_solver_second_update( return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
elif order == 3: elif order == 3:
return self.multistep_dpm_solver_third_update( return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
else: else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
@ -1198,9 +1125,7 @@ class DPM_Solver:
return self.dpm_solver_first_update(x, s, t, return_intermediate=True) return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
def higher_update(x, s, t, **kwargs): def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_second_update( return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
elif order == 3: elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0 r1, r2 = 1.0 / 3.0, 2.0 / 3.0
@ -1211,16 +1136,10 @@ class DPM_Solver:
) )
def higher_update(x, s, t, **kwargs): def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_third_update( return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
else: else:
raise ValueError( raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
"For adaptive step size solver, order must be 2 or 3, got {}".format(
order
)
)
while torch.abs((s - t_0)).mean() > t_err: while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h) t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t) x_lower, lower_noise_kwargs = lower_update(x, s, t)
@ -1231,9 +1150,7 @@ class DPM_Solver:
) )
def norm_fn(v): def norm_fn(v):
return torch.sqrt( return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
)
E = norm_fn((x_higher - x_lower) / delta).max() E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0): if torch.all(E <= 1.0):
@ -1259,9 +1176,7 @@ class DPM_Solver:
Returns: Returns:
xt with shape `(t_size, batch_size, *shape)`. xt with shape `(t_size, batch_size, *shape)`.
""" """
alpha_t, sigma_t = self.noise_schedule.marginal_alpha( alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
t
), self.noise_schedule.marginal_std(t)
if noise is None: if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device) noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape)) x = x.reshape((-1, *x.shape))
@ -1468,9 +1383,7 @@ class DPM_Solver:
) )
elif method == "multistep": elif method == "multistep":
assert steps >= order assert steps >= order
timesteps = self.get_time_steps( timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
)
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# Init the initial values. # Init the initial values.
step = 0 step = 0
@ -1527,10 +1440,7 @@ class DPM_Solver:
model_prev_list[-1] = self.model_fn(x, t) model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]: elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep": if method == "singlestep":
( (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps, steps=steps,
order=order, order=order,
skip_type=skip_type, skip_type=skip_type,
@ -1543,9 +1453,7 @@ class DPM_Solver:
orders = [ orders = [
order, order,
] * K ] * K
timesteps_outer = self.get_time_steps( timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
)
for step, order in enumerate(orders): for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1] s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps( timesteps_inner = self.get_time_steps(
@ -1559,9 +1467,7 @@ class DPM_Solver:
h = lambda_inner[-1] - lambda_inner[0] h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update( x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
x, s, t, order, solver_type=solver_type, r1=r1, r2=r2
)
if self.correcting_xt_fn is not None: if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step) x = self.correcting_xt_fn(x, t, step)
if return_intermediate: if return_intermediate:
@ -1613,9 +1519,7 @@ def interpolate_fn(x, xp, yp):
cand_start_idx, cand_start_idx,
), ),
) )
end_idx = torch.where( end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where( start_idx2 = torch.where(
@ -1628,12 +1532,8 @@ def interpolate_fn(x, xp, yp):
), ),
) )
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather( start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2) end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
).squeeze(2)
end_y = torch.gather(
y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand return cand
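interpolate_fn is batched piecewise-linear interpolation of keypoints (xp, yp) evaluated at x, kept differentiable for autograd; a usage sketch assuming the [N, C] query / [C, K] keypoint layout used here:

import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])     # keypoint x-coordinates, shape [C=1, K=3]
yp = torch.tensor([[0.0, 10.0, 20.0]])   # keypoint y-coordinates, shape [C=1, K=3]
x = torch.tensor([[0.5], [1.5]])         # query points, shape [N=2, C=1]
# interpolate_fn(x, xp, yp) would return approximately [[5.0], [15.0]].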
View File
@ -40,8 +40,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels): def __init__(self, channels):
super().__init__() super().__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
nn.Linear(channels, channels)
) )
self.channels = channels self.channels = channels
View File
@ -95,9 +95,7 @@ def _expand_number(m):
elif num % 100 == 0: elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred" return _inflect.number_to_words(num // 100) + " hundred"
else: else:
return _inflect.number_to_words( return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
num, andword="", zero="oh", group=2
).replace(", ", " ")
else: else:
return _inflect.number_to_words(num, andword="") return _inflect.number_to_words(num, andword="")
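The group=2 branch reads four-digit numbers as spoken pairs; a quick check with inflect (assuming _inflect = inflect.engine() as elsewhere in these cleaners):

import inflect

_inflect = inflect.engine()
spoken = _inflect.number_to_words(1984, andword="", zero="oh", group=2).replace(", ", " ")
print(spoken)  # expected to read roughly "nineteen eighty-four"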
@ -165,9 +163,7 @@ def lev_distance(s1, s2):
if c1 == c2: if c1 == c2:
distances_.append(distances[i1]) distances_.append(distances[i1])
else: else:
distances_.append( distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
)
distances = distances_ distances = distances_
return distances[-1] return distances[-1]
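lev_distance is the standard two-row dynamic-programming edit distance; an equivalent self-contained reference with a sanity check:

def edit_distance(s1, s2):
    prev = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        cur = [i2 + 1]
        for i1, c1 in enumerate(s1):
            cur.append(prev[i1] if c1 == c2 else 1 + min(prev[i1], prev[i1 + 1], cur[-1]))
        prev = cur
    return prev[-1]

assert edit_distance("kitten", "sitting") == 3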
View File
@ -36,12 +36,8 @@ def route_args(router, args, depth):
for key in matched_keys: for key in matched_keys:
val = args[key] val = args[key]
for depth, ((f_args, g_args), routes) in enumerate( for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
zip(routed_args, router[key]) new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
):
new_f_args, new_g_args = map(
lambda route: ({key: val} if route else {}), routes
)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args return routed_args
@ -217,12 +213,8 @@ class Transformer(nn.Module):
layers.append( layers.append(
nn.ModuleList( nn.ModuleList(
[ [
LayerScale( LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm) LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
),
LayerScale(
dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
),
] ]
) )
) )
View File
@ -1,5 +1,7 @@
import os import os
try: import gdown
try:
import gdown
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Sorry, gdown is required in order to download the new BigVGAN vocoder.\n" "Sorry, gdown is required in order to download the new BigVGAN vocoder.\n"
@ -11,9 +13,7 @@ import progressbar
D_STEM = "https://drive.google.com/uc?id=" D_STEM = "https://drive.google.com/uc?id="
DEFAULT_MODELS_DIR = os.path.join( DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
os.path.expanduser("~"), ".cache", "tortoise", "models"
)
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS = { MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
@ -30,6 +30,8 @@ MODELS = {
} }
pbar = None pbar = None
def download_models(specific_models=None): def download_models(specific_models=None):
""" """
Call to download all the models that Tortoise uses. Call to download all the models that Tortoise uses.
@ -62,6 +64,7 @@ def download_models(specific_models=None):
request.urlretrieve(url, model_path, show_progress) request.urlretrieve(url, model_path, show_progress)
print("Done.") print("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR): def get_model_path(model_name, models_dir=MODELS_DIR):
""" """
Get path to given model, download it if it doesn't exist. Get path to given model, download it if it doesn't exist.
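The cache location can be overridden through the environment before the module is imported; a small usage sketch (hypothetical path):

import os

os.environ["TORTOISE_MODELS_DIR"] = "/data/tts/tortoise-models"  # illustrative path
# MODELS_DIR then resolves to that directory instead of ~/.cache/tortoise/models,
# and get_model_path("autoregressive.pth") downloads into / reads from it.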
View File
@ -1,12 +1,12 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import json
from enum import Enum
from typing import Optional, Callable
from dataclasses import dataclass
MAX_WAV_VALUE = 32768.0 MAX_WAV_VALUE = 32768.0
@ -40,18 +40,12 @@ class KernelPredictor(torch.nn.Module):
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers self.conv_layers = conv_layers
kpnet_kernel_channels = ( kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
) # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential( self.input_conv = nn.Sequential(
nn.utils.weight_norm( nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params
),
) )
self.residual_convs = nn.ModuleList() self.residual_convs = nn.ModuleList()
@ -69,9 +63,7 @@ class KernelPredictor(torch.nn.Module):
bias=True, bias=True,
) )
), ),
getattr(nn, kpnet_nonlinear_activation)( getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
**kpnet_nonlinear_activation_params
),
nn.utils.weight_norm( nn.utils.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
@ -81,9 +73,7 @@ class KernelPredictor(torch.nn.Module):
bias=True, bias=True,
) )
), ),
getattr(nn, kpnet_nonlinear_activation)( getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
**kpnet_nonlinear_activation_params
),
) )
) )
self.kernel_conv = nn.utils.weight_norm( self.kernel_conv = nn.utils.weight_norm(
@ -252,17 +242,11 @@ class LVCBlock(torch.nn.Module):
""" """
batch, _, in_length = x.shape batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == ( assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
kernel_length * hop_size
), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2) padding = dilation * int((kernel_size - 1) / 2)
x = F.pad( x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x, (padding, padding), "constant", 0 x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(
2, hop_size + 2 * padding, hop_size
) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation: if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0) x = F.pad(x, (0, dilation), "constant", 0)
@ -270,12 +254,8 @@ class LVCBlock(torch.nn.Module):
3, dilation, dilation 3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size] x = x[:, :, :, :, :hop_size]
x = x.transpose( x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
3, 4 x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(
4, kernel_size, 1
) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel) o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d) o = o.to(memory_format=torch.channels_last_3d)
@ -334,15 +314,11 @@ class UnivNetGenerator(nn.Module):
) )
) )
self.conv_pre = nn.utils.weight_norm( self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
)
self.conv_post = nn.Sequential( self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
),
nn.Tanh(), nn.Tanh(),
) )
@ -399,12 +375,16 @@ class VocType:
constructor: Callable[[], nn.Module] constructor: Callable[[], nn.Module]
model_path: str model_path: str
subkey: Optional[str] = None subkey: Optional[str] = None
def optionally_index(self, model_dict): def optionally_index(self, model_dict):
if self.subkey is not None: if self.subkey is not None:
return model_dict[self.subkey] return model_dict[self.subkey]
return model_dict return model_dict
class VocConf(Enum): class VocConf(Enum):
Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g') Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g")
if __name__ == "__main__": if __name__ == "__main__":
model = UnivNetGenerator() model = UnivNetGenerator()
View File
@ -12,9 +12,7 @@ def max_alignment(s1, s2, skip_character="~", record=None):
""" """
if record is None: if record is None:
record = {} record = {}
assert ( assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
skip_character not in s1
), f"Found the skip character {skip_character} in the provided string, {s1}"
if len(s1) == 0: if len(s1) == 0:
return "" return ""
if len(s2) == 0: if len(s2) == 0:
@ -49,15 +47,9 @@ class Wav2VecAlignment:
""" """
def __init__(self, device="cuda"): def __init__(self, device="cuda"):
self.model = Wav2Vec2ForCTC.from_pretrained( self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
"jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli" self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
).cpu() self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols")
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-large-960h"
)
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
"jbetker/tacotron-symbols"
)
self.device = device self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000): def align(self, audio, expected_text, audio_sample_rate=24000):
@ -117,9 +109,7 @@ class Wav2VecAlignment:
) )
# Now fix up alignments. Anything with -1 should be interpolated. # Now fix up alignments. Anything with -1 should be interpolated.
alignments.append( alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
orig_len
) # This'll get removed but makes the algorithm below more readable.
for i in range(len(alignments)): for i in range(len(alignments)):
if alignments[i] == -1: if alignments[i] == -1:
for j in range(i + 1, len(alignments)): for j in range(i + 1, len(alignments)):
@ -128,9 +118,7 @@ class Wav2VecAlignment:
break break
for j in range(i, next_found_token): for j in range(i, next_found_token):
gap = alignments[next_found_token] - alignments[i - 1] gap = alignments[next_found_token] - alignments[i - 1]
alignments[j] = (j - i + 1) * gap // ( alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
next_found_token - i + 1
) + alignments[i - 1]
return alignments[:-1] return alignments[:-1]
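The -1 fix-up is plain linear interpolation between the nearest located tokens; a standalone illustration of the same loop (assumes the first entry is not -1, like the code above):

def interpolate_missing(alignments, orig_len):
    alignments = alignments + [orig_len]          # sentinel, removed at the end
    for i in range(len(alignments)):
        if alignments[i] == -1:
            j = i
            while alignments[j] == -1:            # find the next located token
                j += 1
            gap = alignments[j] - alignments[i - 1]
            for k in range(i, j):
                alignments[k] = (k - i + 1) * gap // (j - i + 1) + alignments[i - 1]
    return alignments[:-1]

print(interpolate_missing([0, -1, -1, 30], orig_len=40))  # -> [0, 10, 20, 30]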
@ -140,9 +128,7 @@ class Wav2VecAlignment:
splitted = expected_text.split("[") splitted = expected_text.split("[")
fully_split = [splitted[0]] fully_split = [splitted[0]]
for spl in splitted[1:]: for spl in splitted[1:]:
assert ( assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
"]" in spl
), 'Every "[" character must be paired with a "]" with no nesting.'
fully_split.extend(spl.split("]")) fully_split.extend(spl.split("]"))
# At this point, fully_split is a list of strings, with every other string being something that should be redacted. # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
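A quick illustration of that alternating kept/redacted invariant:

expected_text = "The [REDACTED] plan was [very] simple."
splitted = expected_text.split("[")
fully_split = [splitted[0]]
for spl in splitted[1:]:
    fully_split.extend(spl.split("]"))
print(fully_split)  # -> ['The ', 'REDACTED', ' plan was ', 'very', ' simple.']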
View File
@ -6,24 +6,25 @@ from inspect import isfunction
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import nn, einsum from torch import einsum, nn
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [ Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
'pre_softmax_attn',
'post_softmax_attn'
])
LayerIntermediates = namedtuple('Intermediates', [ LayerIntermediates = namedtuple(
'hiddens', "Intermediates",
'attn_intermediates', [
'past_key_values', "hiddens",
]) "attn_intermediates",
"past_key_values",
],
)
# helpers # helpers
def exists(val): def exists(val):
return val is not None return val is not None
@ -38,7 +39,7 @@ def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth return val if isinstance(val, tuple) else (val,) * depth
class always(): class always:
def __init__(self, val): def __init__(self, val):
self.val = val self.val = val
@ -46,7 +47,7 @@ class always():
return self.val return self.val
class not_equals(): class not_equals:
def __init__(self, val): def __init__(self, val):
self.val = val self.val = val
@ -54,7 +55,7 @@ class not_equals():
return x != self.val return x != self.val
class equals(): class equals:
def __init__(self, val): def __init__(self, val):
self.val = val self.val = val
@ -72,14 +73,16 @@ def l2norm(t):
# init helpers # init helpers
def init_zero_(layer): def init_zero_(layer):
nn.init.constant_(layer.weight, 0.) nn.init.constant_(layer.weight, 0.0)
if exists(layer.bias): if exists(layer.bias):
nn.init.constant_(layer.bias, 0.) nn.init.constant_(layer.bias, 0.0)
# keyword argument helpers # keyword argument helpers
def pick_and_pop(keys, d): def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys)) values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values)) return dict(zip(keys, values))
@ -104,12 +107,13 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d): def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs return kwargs_without_prefix, kwargs
# activations # activations
class ReluSquared(nn.Module): class ReluSquared(nn.Module):
def forward(self, x): def forward(self, x):
return F.relu(x) ** 2 return F.relu(x) ** 2
@ -117,30 +121,31 @@ class ReluSquared(nn.Module):
# positional embeddings # positional embeddings
class AbsolutePositionalEmbedding(nn.Module): class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len): def __init__(self, dim, max_seq_len):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.emb = nn.Embedding(max_seq_len, dim) self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x): def forward(self, x):
n = torch.arange(x.shape[1], device=x.device) n = torch.arange(x.shape[1], device=x.device)
pos_emb = self.emb(n) pos_emb = self.emb(n)
pos_emb = rearrange(pos_emb, 'n d -> () n d') pos_emb = rearrange(pos_emb, "n d -> () n d")
return pos_emb * self.scale return pos_emb * self.scale
class FixedPositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_dim=1, offset=0): def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return rearrange(emb, 'n d -> () n d') return rearrange(emb, "n d -> () n d")
class RelativePositionBias(nn.Module): class RelativePositionBias(nn.Module):
@ -166,9 +171,10 @@ class RelativePositionBias(nn.Module):
max_exact = num_buckets // 2 max_exact = num_buckets // 2
is_small = n < max_exact is_small = n < max_exact
val_if_large = max_exact + ( val_if_large = (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) max_exact
).long() + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large) ret += torch.where(is_small, n, val_if_large)
@ -179,10 +185,11 @@ class RelativePositionBias(nn.Module):
q_pos = torch.arange(i, dtype=torch.long, device=device) q_pos = torch.arange(i, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device) k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = k_pos[None, :] - q_pos[:, None] rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, rp_bucket = self._relative_position_bucket(
max_distance=self.max_distance) rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance
)
values = self.relative_attention_bias(rp_bucket) values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j') bias = rearrange(values, "i j h -> () h i j")
return qk_dots + (bias * self.scale) return qk_dots + (bias * self.scale)
@ -191,23 +198,25 @@ class AlibiPositionalBias(nn.Module):
super().__init__() super().__init__()
self.heads = heads self.heads = heads
slopes = torch.Tensor(self._get_slopes(heads)) slopes = torch.Tensor(self._get_slopes(heads))
slopes = rearrange(slopes, 'h -> () h () ()') slopes = rearrange(slopes, "h -> () h () ()")
self.register_buffer('slopes', slopes, persistent=False) self.register_buffer("slopes", slopes, persistent=False)
self.register_buffer('bias', None, persistent=False) self.register_buffer("bias", None, persistent=False)
@staticmethod @staticmethod
def _get_slopes(heads): def _get_slopes(heads):
def get_slopes_power_of_2(n): def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3))) start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start ratio = start
return [start * ratio ** i for i in range(n)] return [start * ratio**i for i in range(n)]
if math.log2(heads).is_integer(): if math.log2(heads).is_integer():
return get_slopes_power_of_2(heads) return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads)) closest_power_of_2 = 2 ** math.floor(math.log2(heads))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ return (
:heads - closest_power_of_2] get_slopes_power_of_2(closest_power_of_2)
+ get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2]
)
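For a power-of-two head count the slopes are a simple geometric series; e.g. with 8 heads:

import math

def alibi_slopes_power_of_2(n):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start ** i for i in range(n)]

print(alibi_slopes_power_of_2(8))
# -> [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]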
def forward(self, qk_dots): def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
@ -216,13 +225,13 @@ class AlibiPositionalBias(nn.Module):
return qk_dots + self.bias[..., :j] return qk_dots + self.bias[..., :j]
bias = torch.arange(j, device=device) bias = torch.arange(j, device=device)
bias = rearrange(bias, 'j -> () () () j') bias = rearrange(bias, "j -> () () () j")
bias = bias * self.slopes bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[1] num_heads_unalibied = h - bias.shape[1]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent=False) self.register_buffer("bias", bias, persistent=False)
return qk_dots + self.bias return qk_dots + self.bias
@ -247,8 +256,8 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
else: else:
i_arange = torch.arange(i, device=device) i_arange = torch.arange(i, device=device)
j_arange = torch.arange(j, device=device) j_arange = torch.arange(j, device=device)
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1")
self.register_buffer('bias', bias, persistent=False) self.register_buffer("bias", bias, persistent=False)
if self.bidirectional: if self.bidirectional:
past_slopes = get_slopes(self.learned_logslopes) past_slopes = get_slopes(self.learned_logslopes)
@ -264,18 +273,18 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq) self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, device): def forward(self, max_seq_len, device):
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq) freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
return rearrange(emb, 'n d -> () () n d') return rearrange(emb, "n d -> () () n d")
def rotate_half(x): def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j=2) x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2) x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
@ -288,6 +297,7 @@ def apply_rotary_pos_emb(t, freqs):
# norms # norms
class Scale(nn.Module): class Scale(nn.Module):
def __init__(self, value, fn): def __init__(self, value, fn):
super().__init__() super().__init__()
@ -323,7 +333,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module): class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5): def __init__(self, dim, eps=1e-5):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(1)) self.g = nn.Parameter(torch.ones(1))
@ -335,7 +345,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8): def __init__(self, dim, eps=1e-8):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
@ -347,7 +357,7 @@ class RMSNorm(nn.Module):
class RMSScaleShiftNorm(nn.Module): class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8): def __init__(self, dim, eps=1e-8):
super().__init__() super().__init__()
self.scale = dim ** -0.5 self.scale = dim**-0.5
self.eps = eps self.eps = eps
self.g = nn.Parameter(torch.ones(dim)) self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim * 2, dim * 2) self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
@ -364,6 +374,7 @@ class RMSScaleShiftNorm(nn.Module):
# residual and residual gates # residual and residual gates
class Residual(nn.Module): class Residual(nn.Module):
def __init__(self, dim, scale_residual=False): def __init__(self, dim, scale_residual=False):
super().__init__() super().__init__()
@ -386,24 +397,22 @@ class GRUGating(nn.Module):
if exists(self.residual_scale): if exists(self.residual_scale):
residual = residual * self.residual_scale residual = residual * self.residual_scale
gated_output = self.gru( gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d"))
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
return gated_output.reshape_as(x) return gated_output.reshape_as(x)
# token shifting # token shifting
def shift(t, amount, mask=None): def shift(t, amount, mask=None):
if amount == 0: if amount == 0:
return t return t
if exists(mask): if exists(mask):
t = t.masked_fill(~mask[..., None], 0.) t = t.masked_fill(~mask[..., None], 0.0)
return F.pad(t, (0, 0, amount, -amount), value=0.) return F.pad(t, (0, 0, amount, -amount), value=0.0)
class ShiftTokens(nn.Module): class ShiftTokens(nn.Module):
@ -413,7 +422,7 @@ class ShiftTokens(nn.Module):
self.shifts = tuple(shifts) self.shifts = tuple(shifts)
def forward(self, x, **kwargs): def forward(self, x, **kwargs):
mask = kwargs.get('mask', None) mask = kwargs.get("mask", None)
shifts = self.shifts shifts = self.shifts
segments = len(shifts) segments = len(shifts)
feats_per_shift = x.shape[-1] // segments feats_per_shift = x.shape[-1] // segments
@ -426,6 +435,7 @@ class ShiftTokens(nn.Module):
# feedforward # feedforward
class GLU(nn.Module): class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation): def __init__(self, dim_in, dim_out, activation):
super().__init__() super().__init__()
@ -446,24 +456,23 @@ class FeedForward(nn.Module):
glu=False, glu=False,
relu_squared=False, relu_squared=False,
post_act_ln=False, post_act_ln=False,
dropout=0., dropout=0.0,
zero_init_output=False zero_init_output=False,
): ):
super().__init__() super().__init__()
inner_dim = int(dim * mult) inner_dim = int(dim * mult)
dim_out = default(dim_out, dim) dim_out = default(dim_out, dim)
activation = ReluSquared() if relu_squared else nn.GELU() activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential( project_in = (
nn.Linear(dim, inner_dim), nn.Sequential(nn.Linear(dim, inner_dim), activation) if not glu else GLU(dim, inner_dim, activation)
activation )
) if not glu else GLU(dim, inner_dim, activation)
self.net = nn.Sequential( self.net = nn.Sequential(
project_in, project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out) nn.Linear(inner_dim, dim_out),
) )
# init last linear layer to 0 # init last linear layer to 0
@ -476,6 +485,7 @@ class FeedForward(nn.Module):
# attention. # attention.
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
@ -486,11 +496,11 @@ class Attention(nn.Module):
talking_heads=False, talking_heads=False,
head_scale=False, head_scale=False,
collab_heads=False, collab_heads=False,
collab_compression=.3, collab_compression=0.3,
sparse_topk=None, sparse_topk=None,
use_entmax15=False, use_entmax15=False,
num_mem_kv=0, num_mem_kv=0,
dropout=0., dropout=0.0,
on_attn=False, on_attn=False,
gate_values=False, gate_values=False,
zero_init_output=False, zero_init_output=False,
@ -502,7 +512,7 @@ class Attention(nn.Module):
rel_pos_max_distance=128, rel_pos_max_distance=128,
): ):
super().__init__() super().__init__()
self.scale = dim_head ** -0.5 self.scale = dim_head**-0.5
self.heads = heads self.heads = heads
self.causal = causal self.causal = causal
@ -532,8 +542,9 @@ class Attention(nn.Module):
# cosine sim attention # cosine sim attention
self.qk_norm = qk_norm self.qk_norm = qk_norm
if qk_norm: if qk_norm:
scale_init_value = default(scale_init_value, scale_init_value = default(
-3) # if not provided, initialize as though it were sequence length of 1024 scale_init_value, -3
) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
# talking heads # talking heads
@ -565,9 +576,16 @@ class Attention(nn.Module):
self.rel_pos_bias = rel_pos_bias self.rel_pos_bias = rel_pos_bias
if rel_pos_bias: if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' assert (
self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads, rel_pos_num_buckets <= rel_pos_max_distance
num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) ), "number of relative position buckets must be less than the relative position max distance"
self.rel_pos = RelativePositionBias(
scale=dim_head**0.5,
causal=causal,
heads=heads,
num_buckets=rel_pos_num_buckets,
max_distance=rel_pos_max_distance,
)
# init output projection 0 # init output projection 0
if zero_init_output: if zero_init_output:
@ -586,8 +604,16 @@ class Attention(nn.Module):
mem=None, mem=None,
layer_past=None, layer_past=None,
): ):
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = (
context) *x.shape,
self.heads,
self.talking_heads,
self.collab_heads,
self.head_scale,
self.scale,
x.device,
exists(context),
)
kv_input = default(context, x) kv_input = default(context, x)
q_input = x q_input = x
@ -609,11 +635,11 @@ class Attention(nn.Module):
v = self.to_v(v_input) v = self.to_v(v_input)
if not collab_heads: if not collab_heads:
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
else: else:
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
k = rearrange(k, 'b n d -> b () n d') k = rearrange(k, "b n d -> b () n d")
v = rearrange(v, 'b n (h d) -> b h n d', h=h) v = rearrange(v, "b n (h d) -> b h n d", h=h)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past past_key, past_value = layer_past
@ -633,12 +659,12 @@ class Attention(nn.Module):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()') q_mask = rearrange(q_mask, "b i -> b () i ()")
k_mask = rearrange(k_mask, 'b j -> b () () j') k_mask = rearrange(k_mask, "b j -> b () () j")
input_mask = q_mask * k_mask input_mask = q_mask * k_mask
if self.num_mem_kv > 0: if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2) k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2) v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask): if exists(input_mask):
@ -651,7 +677,7 @@ class Attention(nn.Module):
q, k = map(l2norm, (q, k)) q, k = map(l2norm, (q, k))
scale = 1 / (self.scale.exp().clamp(min=1e-2)) scale = 1 / (self.scale.exp().clamp(min=1e-2))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale
mask_value = max_neg_value(dots) mask_value = max_neg_value(dots)
if exists(prev_attn): if exists(prev_attn):
@ -660,7 +686,7 @@ class Attention(nn.Module):
pre_softmax_attn = dots.clone() pre_softmax_attn = dots.clone()
if talking_heads: if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous()
if self.rel_pos_bias: if self.rel_pos_bias:
dots = self.rel_pos(dots) dots = self.rel_pos(dots)
@ -670,18 +696,20 @@ class Attention(nn.Module):
del input_mask del input_mask
if exists(attn_mask): if exists(attn_mask):
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' assert (
2 <= attn_mask.ndim <= 4
), "attention mask must have greater than 2 dimensions but less than or equal to 4"
if attn_mask.ndim == 2: if attn_mask.ndim == 2:
attn_mask = rearrange(attn_mask, 'i j -> () () i j') attn_mask = rearrange(attn_mask, "i j -> () () i j")
elif attn_mask.ndim == 3: elif attn_mask.ndim == 3:
attn_mask = rearrange(attn_mask, 'h i j -> () h i j') attn_mask = rearrange(attn_mask, "h i j -> () h i j")
dots.masked_fill_(~attn_mask, mask_value) dots.masked_fill_(~attn_mask, mask_value)
if exists(self.max_attend_past): if exists(self.max_attend_past):
i, j = dots.shape[-2:] i, j = dots.shape[-2:]
range_q = torch.arange(j - i, j, device=device) range_q = torch.arange(j - i, j, device=device)
range_k = torch.arange(j, device=device) range_k = torch.arange(j, device=device)
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j")
mask = dist > self.max_attend_past mask = dist > self.max_attend_past
dots.masked_fill_(mask, mask_value) dots.masked_fill_(mask, mask_value)
del mask del mask
@ -689,7 +717,7 @@ class Attention(nn.Module):
if self.causal: if self.causal:
i, j = dots.shape[-2:] i, j = dots.shape[-2:]
r = torch.arange(i, device=device) r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False) mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value) dots.masked_fill_(mask, mask_value)
del mask del mask
@ -707,23 +735,20 @@ class Attention(nn.Module):
attn = self.dropout(attn) attn = self.dropout(attn)
if talking_heads: if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v) out = einsum("b h i j, b h j d -> b h i d", attn, v)
if head_scale: if head_scale:
out = out * self.head_scale_params out = out * self.head_scale_params
out = rearrange(out, 'b h n d -> b n (h d)') out = rearrange(out, "b h n d -> b n (h d)")
if exists(self.to_v_gate): if exists(self.to_v_gate):
gates = self.to_v_gate(x) gates = self.to_v_gate(x)
out = out * gates.sigmoid() out = out * gates.sigmoid()
intermediates = Intermediates( intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn)
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates, k_cache, v_cache return self.to_out(out), intermediates, k_cache, v_cache
@ -761,20 +786,20 @@ class AttentionLayers(nn.Module):
use_qk_norm_attn=False, use_qk_norm_attn=False,
qk_norm_attn_seq_len=None, qk_norm_attn_seq_len=None,
zero_init_branch_output=False, zero_init_branch_output=False,
**kwargs **kwargs,
): ):
super().__init__() super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
self.dim = dim self.dim = dim
self.depth = depth self.depth = depth
self.layers = nn.ModuleList([]) self.layers = nn.ModuleList([])
self.causal = causal self.causal = causal
rel_pos_bias = 'rel_pos_bias' in attn_kwargs rel_pos_bias = "rel_pos_bias" in attn_kwargs
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
@ -782,17 +807,18 @@ class AttentionLayers(nn.Module):
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
assert not ( assert not (
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' alibi_pos_bias and rel_pos_bias
), "you can only choose Alibi positional bias or T5 relative positional bias, not both"
if alibi_pos_bias: if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads) alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' assert alibi_num_heads <= heads, "number of ALiBi heads must be less than the total number of heads"
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
else: else:
self.rel_pos = None self.rel_pos = None
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm"
self.pre_norm = pre_norm self.pre_norm = pre_norm
self.sandwich_norm = sandwich_norm self.sandwich_norm = sandwich_norm
@ -809,27 +835,30 @@ class AttentionLayers(nn.Module):
branch_fn = Rezero if use_rezero else None branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross: if cross_attend and not only_cross:
default_block = ('a', 'c', 'f') default_block = ("a", "c", "f")
elif cross_attend and only_cross: elif cross_attend and only_cross:
default_block = ('c', 'f') default_block = ("c", "f")
else: else:
default_block = ('a', 'f') default_block = ("a", "f")
if macaron: if macaron:
default_block = ('f',) + default_block default_block = ("f",) + default_block
# qk normalization # qk normalization
if use_qk_norm_attn: if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( attn_scale_init_value = (
qk_norm_attn_seq_len) else None -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len))
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} if exists(qk_norm_attn_seq_len)
else None
)
attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value}
# zero init # zero init
if zero_init_branch_output: if zero_init_branch_output:
attn_kwargs = {**attn_kwargs, 'zero_init_output': True} attn_kwargs = {**attn_kwargs, "zero_init_output": True}
ff_kwargs = {**ff_kwargs, 'zero_init_output': True} ff_kwargs = {**ff_kwargs, "zero_init_output": True}
# calculate layer block order # calculate layer block order
@ -837,23 +866,23 @@ class AttentionLayers(nn.Module):
layer_types = custom_layers layer_types = custom_layers
elif exists(par_ratio): elif exists(par_ratio):
par_depth = depth * len(default_block) par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range' assert 1 < par_ratio <= par_depth, "par ratio out of range"
default_block = tuple(filter(not_equals('f'), default_block)) default_block = tuple(filter(not_equals("f"), default_block))
par_attn = par_depth // par_ratio par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio' assert len(default_block) <= par_width, "default block is too large for par_ratio"
par_block = default_block + ('f',) * (par_width - len(default_block)) par_block = default_block + ("f",) * (par_width - len(default_block))
par_head = par_block * par_attn par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head)) layer_types = par_head + ("f",) * (par_depth - len(par_head))
elif exists(sandwich_coef): elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth"
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef
else: else:
layer_types = default_block * depth layer_types = default_block * depth
self.layer_types = layer_types self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals('a'), layer_types))) self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
# calculate token shifting # calculate token shifting
@ -864,15 +893,15 @@ class AttentionLayers(nn.Module):
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
is_last_layer = ind == (len(self.layer_types) - 1) is_last_layer = ind == (len(self.layer_types) - 1)
if layer_type == 'a': if layer_type == "a":
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c': elif layer_type == "c":
layer = Attention(dim, heads=heads, **attn_kwargs) layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f': elif layer_type == "f":
layer = FeedForward(dim, **ff_kwargs) layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer) layer = layer if not macaron else Scale(0.5, layer)
else: else:
raise Exception(f'invalid layer type {layer_type}') raise Exception(f"invalid layer type {layer_type}")
if layer_shift_tokens > 0: if layer_shift_tokens > 0:
shift_range_upper = layer_shift_tokens + 1 shift_range_upper = layer_shift_tokens + 1
@ -885,23 +914,15 @@ class AttentionLayers(nn.Module):
residual_fn = GRUGating if gate_residual else Residual residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual=scale_residual) residual = residual_fn(dim, scale_residual=scale_residual)
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c")
pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
norms = nn.ModuleList([ norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm])
pre_branch_norm,
post_branch_norm,
post_main_norm
])
self.layers.append(nn.ModuleList([ self.layers.append(nn.ModuleList([norms, layer, residual]))
norms,
layer,
residual
]))
def forward( def forward(
self, self,
@ -918,9 +939,10 @@ class AttentionLayers(nn.Module):
expected_seq_len=None, expected_seq_len=None,
): ):
assert not (self.cross_attend ^ (exists(context) or exists( assert not (
full_context))), 'context must be passed in if cross_attend is set to True' self.cross_attend ^ (exists(context) or exists(full_context))
assert context is None or full_context is None, 'only one of full_context or context can be provided' ), "context must be passed in if cross_attend is set to True"
assert context is None or full_context is None, "only one of full_context or context can be provided"
hiddens = [] hiddens = []
intermediates = [] intermediates = []
@ -930,24 +952,28 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
norm_args = {} norm_args = {}
if exists(norm_scale_shift_inp): if exists(norm_scale_shift_inp):
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp
rotary_pos_emb = None rotary_pos_emb = None
if exists(self.rotary_pos_emb): if exists(self.rotary_pos_emb):
if not self.training and self.causal: if not self.training and self.causal:
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" assert (
expected_seq_len is not None
), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
elif expected_seq_len is None: elif expected_seq_len is None:
expected_seq_len = 0 expected_seq_len = 0
seq_len = x.shape[1] seq_len = x.shape[1]
if past_key_values is not None: if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2] seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) max_rotary_emb_length = max(
list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
)
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
present_key_values = [] present_key_values = []
cross_attn_count = 0 cross_attn_count = 0
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
if layer_type == 'a': if layer_type == "a":
layer_mem = mems.pop(0) if mems else None layer_mem = mems.pop(0) if mems else None
residual = x residual = x
@ -957,26 +983,39 @@ class AttentionLayers(nn.Module):
if exists(pre_branch_norm): if exists(pre_branch_norm):
x = pre_branch_norm(x, **norm_args) x = pre_branch_norm(x, **norm_args)
if layer_type == 'a' or layer_type == 'c': if layer_type == "a" or layer_type == "c":
if past_key_values is not None: if past_key_values is not None:
layer_kv = past_key_values.pop(0) layer_kv = past_key_values.pop(0)
layer_past = tuple(s.to(x.device) for s in layer_kv) layer_past = tuple(s.to(x.device) for s in layer_kv)
else: else:
layer_past = None layer_past = None
if layer_type == 'a': if layer_type == "a":
out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, out, inter, k, v = block(
prev_attn, layer_mem, layer_past) x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past
elif layer_type == 'c': )
elif layer_type == "c":
if exists(full_context): if exists(full_context):
out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None, out, inter, k, v = block(
None, prev_attn, None, layer_past) x,
full_context[cross_attn_count],
mask,
context_mask,
None,
None,
None,
prev_attn,
None,
layer_past,
)
else: else:
out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) out, inter, k, v = block(
elif layer_type == 'f': x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past
)
elif layer_type == "f":
out = block(x) out = block(x)
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None: if layer_type == "a" or layer_type == "c" and present_key_values is not None:
present_key_values.append((k.detach(), v.detach())) present_key_values.append((k.detach(), v.detach()))
if exists(post_branch_norm): if exists(post_branch_norm):
@ -984,28 +1023,26 @@ class AttentionLayers(nn.Module):
x = residual_fn(out, residual) x = residual_fn(out, residual)
if layer_type in ('a', 'c'): if layer_type in ("a", "c"):
intermediates.append(inter) intermediates.append(inter)
if layer_type == 'a' and self.residual_attn: if layer_type == "a" and self.residual_attn:
prev_attn = inter.pre_softmax_attn prev_attn = inter.pre_softmax_attn
elif layer_type == 'c' and self.cross_residual_attn: elif layer_type == "c" and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn prev_cross_attn = inter.pre_softmax_attn
if exists(post_main_norm): if exists(post_main_norm):
x = post_main_norm(x, **norm_args) x = post_main_norm(x, **norm_args)
if layer_type == 'c': if layer_type == "c":
cross_attn_count += 1 cross_attn_count += 1
if layer_type == 'f': if layer_type == "f":
hiddens.append(x) hiddens.append(x)
if return_hiddens: if return_hiddens:
intermediates = LayerIntermediates( intermediates = LayerIntermediates(
hiddens=hiddens, hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values
attn_intermediates=intermediates,
past_key_values=present_key_values
) )
return x, intermediates return x, intermediates
@ -1015,13 +1052,13 @@ class AttentionLayers(nn.Module):
class Encoder(AttentionLayers): class Encoder(AttentionLayers):
def __init__(self, **kwargs): def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder' assert "causal" not in kwargs, "cannot set causality on encoder"
super().__init__(causal=False, **kwargs) super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers): class Decoder(AttentionLayers):
def __init__(self, **kwargs): def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder' assert "causal" not in kwargs, "cannot set causality on decoder"
super().__init__(causal=True, **kwargs) super().__init__(causal=True, **kwargs)
@ -1031,22 +1068,13 @@ class CrossAttender(AttentionLayers):
class ViTransformerWrapper(nn.Module): class ViTransformerWrapper(nn.Module):
def __init__( def __init__(self, *, image_size, patch_size, attn_layers, num_classes=None, dropout=0.0, emb_dropout=0.0):
self,
*,
image_size,
patch_size,
attn_layers,
num_classes=None,
dropout=0.,
emb_dropout=0.
):
super().__init__() super().__init__()
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder"
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' assert image_size % patch_size == 0, "image dimensions must be divisible by the patch size"
dim = attn_layers.dim dim = attn_layers.dim
num_patches = (image_size // patch_size) ** 2 num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2 patch_dim = 3 * patch_size**2
self.patch_size = patch_size self.patch_size = patch_size
@ -1059,20 +1087,16 @@ class ViTransformerWrapper(nn.Module):
self.norm = nn.LayerNorm(dim) self.norm = nn.LayerNorm(dim)
self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
def forward( def forward(self, img, return_embeddings=False):
self,
img,
return_embeddings=False
):
p = self.patch_size p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
x = self.patch_to_embedding(x) x = self.patch_to_embedding(x)
b, n, _ = x.shape b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
x = torch.cat((cls_tokens, x), dim=1) x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding[:, :(n + 1)] x = x + self.pos_embedding[:, : (n + 1)]
x = self.dropout(x) x = self.dropout(x)
x = self.attn_layers(x) x = self.attn_layers(x)
@ -1092,15 +1116,15 @@ class TransformerWrapper(nn.Module):
max_seq_len, max_seq_len,
attn_layers, attn_layers,
emb_dim=None, emb_dim=None,
max_mem_len=0., max_mem_len=0.0,
shift_mem_down=0, shift_mem_down=0,
emb_dropout=0., emb_dropout=0.0,
num_memory_tokens=None, num_memory_tokens=None,
tie_embedding=False, tie_embedding=False,
use_pos_emb=True use_pos_emb=True,
): ):
super().__init__() super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim dim = attn_layers.dim
emb_dim = default(emb_dim, dim) emb_dim = default(emb_dim, dim)
@ -1110,8 +1134,11 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim) self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( self.pos_emb = (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@ -1140,7 +1167,7 @@ class TransformerWrapper(nn.Module):
return_attn=False, return_attn=False,
mems=None, mems=None,
use_cache=False, use_cache=False,
**kwargs **kwargs,
): ):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x) x = self.token_emb(x)
@ -1150,7 +1177,7 @@ class TransformerWrapper(nn.Module):
x = self.project_emb(x) x = self.project_emb(x)
if num_mem > 0: if num_mem > 0:
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
x = torch.cat((mem, x), dim=1) x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens # auto-handle masking after appending memory tokens
@ -1158,7 +1185,7 @@ class TransformerWrapper(nn.Module):
mask = F.pad(mask, (num_mem, 0), value=True) mask = F.pad(mask, (num_mem, 0), value=True)
if self.shift_mem_down and exists(mems): if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
mems = [*mems_r, *mems_l] mems = [*mems_r, *mems_l]
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
@ -1186,25 +1213,20 @@ class TransformerWrapper(nn.Module):
class ContinuousTransformerWrapper(nn.Module): class ContinuousTransformerWrapper(nn.Module):
def __init__( def __init__(
self, self, *, max_seq_len, attn_layers, dim_in=None, dim_out=None, emb_dim=None, emb_dropout=0.0, use_pos_emb=True
*,
max_seq_len,
attn_layers,
dim_in=None,
dim_out=None,
emb_dim=None,
emb_dropout=0.,
use_pos_emb=True
): ):
super().__init__() super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim dim = attn_layers.dim
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( self.pos_emb = (
use_pos_emb and not attn_layers.has_pos_emb) else always(0) AbsolutePositionalEmbedding(dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout) self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
@ -1214,16 +1236,7 @@ class ContinuousTransformerWrapper(nn.Module):
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward( def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs):
self,
x,
return_embeddings=False,
mask=None,
return_attn=False,
mems=None,
use_cache=False,
**kwargs
):
b, n, _, device = *x.shape, x.device b, n, _, device = *x.shape, x.device
x = self.project_in(x) x = self.project_in(x)
@ -1245,4 +1258,3 @@ class ContinuousTransformerWrapper(nn.Module):
if len(res) > 1: if len(res) > 1:
return tuple(res) return tuple(res)
return res[0] return res[0]


@ -1,31 +0,0 @@
from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock
class FramePriorNet(nn.Module):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, num_res_blocks=13, num_conv_blocks=2
):
super().__init__()
self.res_blocks = nn.ModuleList()
for idx in range(num_res_blocks):
block = Conv1dBNBlock(
in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == num_res_blocks else hidden_channels,
hidden_channels,
kernel_size,
1,
num_conv_blocks,
)
self.res_blocks.append(block)
def forward(self, x, x_mask=None):
if x_mask is None:
x_mask = 1.0
o = x * x_mask
for block in self.res_blocks:
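# each Conv1dBN block is wrapped in a residual connection (its input is added back onto its output)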
res = o
o = block(o)
o = o + res
if x_mask is not None:
o = o * x_mask
return o


@ -1,91 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, filter):
super().__init__()
self.num_mel = num_mel
start_index = 2
end_index = filter / 16
i = start_index
filt_len = []
while i <= end_index:
i = i * 2
filt_len.append(i)
filt_len.append(i)
filters = [1] + filt_len
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=128, batch_first=True, bidirectional=False  # 128 matches the [B, 128] embedding returned below
)
def forward(self, inputs, input_lengths):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
valid_lengths = input_lengths.float() # [batch_size]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
# Create the post conv width mask based on the valid lengths of the output of the convolution.
# The valid length of the output of a convolution on a varying-length input is
# ceil(input_length / stride) + 1 for stride=2 and padding=2.
# For example (kernel_size=3, stride=2, padding=2):
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
# _____
#     x _____
#         x _____
#             x ____
#                 x
# x x x x -> Output valid length = 4
# Since every example in the batch is zero padded and therefore has its own valid length,
# we need to mask off all the values AFTER the valid length for each example in the batch.
# Otherwise, the convolutions create noise and a lot of spurious information.
valid_lengths = (valid_lengths / 2).float()
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
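# worked example, applied once per conv layer: 100 valid frames -> ceil(100 / 2) + 1 = 51 -> ceil(51 / 2) + 1 = 27 -> ...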
post_conv_max_width = x.size(2)
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
len(valid_lengths), post_conv_max_width
) < valid_lengths.unsqueeze(1)
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
x = x * mask
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
post_conv_input_lengths = valid_lengths
packed_seqs = nn.utils.rnn.pack_padded_sequence(
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.recurrence.flatten_parameters()
_, (ht, _) = self.recurrence(packed_seqs)
last_output = ht[-1]
return last_output.to(inputs.device) # [B, 128]


@ -1,67 +0,0 @@
import torch
from torch.autograd import Function
class VectorQuantization(Function):
@staticmethod
def forward(ctx, inputs, codebook):
with torch.no_grad():
embedding_size = codebook.size(1)
inputs_size = inputs.size()
inputs_flatten = inputs.view(-1, embedding_size)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
# Compute the distances to the codebook
distances = torch.addmm(codebook_sqr + inputs_sqr,
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
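# torch.addmm evaluates (||c||^2 + ||x||^2) - 2 * x @ c^T, i.e. the squared Euclidean
# distance from every flattened input vector to every codebook entry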
_, indices_flatten = torch.min(distances, dim=1)
indices = indices_flatten.view(*inputs_size[:-1])
ctx.mark_non_differentiable(indices)
return indices
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError('Trying to call `.grad()` on graph containing '
'`VectorQuantization`. The function `VectorQuantization` '
'is not differentiable. Use `VectorQuantizationStraightThrough` '
'if you want a straight-through estimator of the gradient.')
class VectorQuantizationStraightThrough(Function):
@staticmethod
def forward(ctx, inputs, codebook):
indices = vq(inputs, codebook)
indices_flatten = indices.view(-1)
ctx.save_for_backward(indices_flatten, codebook)
ctx.mark_non_differentiable(indices_flatten)
codes_flatten = torch.index_select(codebook, dim=0,
index=indices_flatten)
codes = codes_flatten.view_as(inputs)
return (codes, indices_flatten)
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
# Straight-through estimator
grad_inputs = grad_output.clone()
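# the quantization step is treated as the identity in the backward pass,
# so the upstream gradient flows back to `inputs` unchanged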
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
embedding_size = codebook.size(1)
grad_output_flatten = (grad_output.contiguous()
.view(-1, embedding_size))
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output_flatten)
return (grad_inputs, grad_codebook)
vq = VectorQuantization.apply
vq_st = VectorQuantizationStraightThrough.apply
__all__ = ["vq", "vq_st"]


@ -1,305 +0,0 @@
import copy
from abc import abstractmethod
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields as class attributes
for key in config:
setattr(self, key, config[key])
# layers
self.embedding = None
self.encoder = None
self.decoder = None
self.postnet = None
# init tensors
self.embedded_speakers = None
self.embedded_speakers_projected = None
# global style token
if self.gst and self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# Capacitron
if self.capacitron_vae and self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
self.capacitron_vae_layer = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
"""Set missing fields to their default values"""
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
return None
#############################
# INIT FUNCTIONS
#############################
def _init_backward_decoder(self):
"""Init the backward decoder for Forward-Backward decoding."""
self.decoder_backward = copy.deepcopy(self.decoder)
def _init_coarse_decoder(self):
"""Init the coarse decoder for Double-Decoder Consistency."""
self.coarse_decoder = copy.deepcopy(self.decoder)
self.coarse_decoder.r_init = self.ddc_r
self.coarse_decoder.set_r(self.ddc_r)
#############################
# CORE FUNCTIONS
#############################
@abstractmethod
def forward(self):
pass
@abstractmethod
def inference(self):
pass
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals.
Args:
config (Coqpit): model configuration.
checkpoint_path (str): path to checkpoint file.
eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
if "r" in state:
# set r from the state (for compatibility with older checkpoints)
self.decoder.set_r(state["r"])
elif "config" in state:
# set r from config used at training time (for inference)
self.decoder.set_r(state["config"]["r"])
else:
# set r from the new config (for new-models)
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
assert not self.training
def get_criterion(self) -> nn.Module:
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
##########################
# TEST AND LOG FUNCTIONS #
##########################
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return {"figures": test_figures, "audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
#############################
# COMMON COMPUTE FUNCTIONS
#############################
def compute_masks(self, text_lengths, mel_lengths):
"""Compute masks against sequence paddings."""
# B x T_in_max (boolean)
input_mask = sequence_mask(text_lengths)
output_mask = None
if mel_lengths is not None:
max_len = mel_lengths.max()
r = self.decoder.r
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
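# round max_len up to the next multiple of the reduction factor r, e.g. r=2 and max_len=7 -> 8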
output_mask = sequence_mask(mel_lengths, max_len=max_len)
return input_mask, output_mask
def _backward_pass(self, mel_specs, encoder_outputs, mask):
"""Run backwards decoder"""
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
return decoder_outputs_b, alignments_b
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
"""Double Decoder Consistency"""
T = mel_specs.shape[1]
if T % self.coarse_decoder.r > 0:
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
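# pad the time axis so the target length is a multiple of the coarse decoder's reduction factor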
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask
)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode="nearest",
).transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
return decoder_outputs_backward, alignments_backward
#############################
# EMBEDDING FUNCTIONS
#############################
def compute_gst(self, inputs, style_input, speaker_embedding=None):
"""Compute global style token"""
if isinstance(style_input, dict):
# multiply each style token with a weight
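# e.g. a (hypothetical) style_input of {"0": 0.3, "2": -0.1} scales GST token 0 by 0.3 and token 2 by -0.1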
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
if speaker_embedding is not None:
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
for k_token, v_amplifier in style_input.items():
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
# ignore style token and return zero tensor
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
else:
# compute style tokens
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(
VAE_outputs,
posterior_distribution,
prior_distribution,
capacitron_beta,
) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable
)
VAE_outputs = VAE_outputs.to(inputs.device)
encoder_output = self._concat_speaker_embedding(
inputs, VAE_outputs
) # concatenate to the output of the basic tacotron encoder
return (
encoder_output,
posterior_distribution,
prior_distribution,
capacitron_beta,
)
@staticmethod
def _add_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = outputs + embedded_speakers_
return outputs
@staticmethod
def _concat_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
return outputs
#############################
# CALLBACKS
#############################
def on_epoch_start(self, trainer):
"""Callback for setting values wrt gradual training schedule.
Args:
trainer (TrainerTTS): TTS trainer object that is used to train this model.
"""
if self.gradual_training:
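# the gradual training schedule returns a new reduction factor `r` (output frames per decoder step)
# and batch size for the current step; schedules typically start with a large `r` and lower it over time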
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
trainer.config.r = r
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")


@ -2,12 +2,12 @@
import os import os
import random import random
from contextlib import contextmanager
from time import time from time import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
@ -16,22 +16,14 @@ from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP from TTS.tts.layers.tortoise.clvp import CLVP
from TTS.tts.layers.tortoise.cvvp import CVVP from TTS.tts.layers.tortoise.cvvp import CVVP
from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.diffusion import (
SpacedDiffusion,
get_named_beta_schedule,
space_timesteps,
)
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from contextlib import contextmanager
def pad_or_truncate(t, length): def pad_or_truncate(t, length):
""" """
@ -56,9 +48,7 @@ def load_discrete_vocoder_diffuser(
Helper function to load a GaussianDiffusion instance configured for use as a vocoder. Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
""" """
return SpacedDiffusion( return SpacedDiffusion(
use_timesteps=space_timesteps( use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
trained_diffusion_steps, [desired_diffusion_steps]
),
model_mean_type="epsilon", model_mean_type="epsilon",
model_var_type="learned_range", model_var_type="learned_range",
loss_type="mse", loss_type="mse",
@ -141,7 +131,7 @@ def do_spectrogram_diffusion(
output_shape, output_shape,
noise=noise, noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=verbose progress=verbose,
) )
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
@ -166,9 +156,7 @@ def classify_audio_clip(clip):
kernel_size=5, kernel_size=5,
distribute_zero_label=False, distribute_zero_label=False,
) )
classifier.load_state_dict( classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
)
clip = clip.cpu().unsqueeze(0) clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1) results = F.softmax(classifier(clip), dim=-1)
return results[0][0] return results[0][0]
@ -238,9 +226,7 @@ class TextToSpeech:
self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed
self.models_dir = models_dir self.models_dir = models_dir
self.autoregressive_batch_size = ( self.autoregressive_batch_size = (
pick_best_batch_size_for_gpu() pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
if autoregressive_batch_size is None
else autoregressive_batch_size
) )
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -274,9 +260,7 @@ class TextToSpeech:
self.autoregressive.load_state_dict(torch.load(ar_path)) self.autoregressive.load_state_dict(torch.load(ar_path))
self.autoregressive.post_init_gpt2_config(kv_cache) self.autoregressive.post_init_gpt2_config(kv_cache)
diff_path = diff_checkpoint or get_model_path( diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
"diffusion_decoder.pth", models_dir
)
self.diffusion = ( self.diffusion = (
DiffusionTts( DiffusionTts(
model_channels=1024, model_channels=1024,
@@ -365,9 +349,7 @@ class TextToSpeech:
             .cpu()
             .eval()
         )
-        self.cvvp.load_state_dict(
-            torch.load(get_model_path("cvvp.pth", self.models_dir))
-        )
+        self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))

     def get_conditioning_latents(
         self,
@@ -407,11 +389,7 @@ class TextToSpeech:
         DURS_CONST = 102400
         for ls in voice_samples:
             # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
-            sample = (
-                torchaudio.functional.resample(ls[0], 22050, 24000)
-                if original_tortoise
-                else ls[1]
-            )
+            sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
             if latent_averaging_mode == 0:
                 sample = pad_or_truncate(sample, DURS_CONST)
                 cond_mel = wav_to_univnet_mel(
@@ -426,9 +404,7 @@ class TextToSpeech:
                 if latent_averaging_mode == 2:
                     temp_diffusion_conds = []
                 for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
-                    current_sample = sample[
-                        :, chunk * DURS_CONST : (chunk + 1) * DURS_CONST
-                    ]
+                    current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST]
                     current_sample = pad_or_truncate(current_sample, DURS_CONST)
                     cond_mel = wav_to_univnet_mel(
                         current_sample.to(self.device),
@@ -440,9 +416,7 @@ class TextToSpeech:
                     elif latent_averaging_mode == 2:
                         temp_diffusion_conds.append(cond_mel)
                 if latent_averaging_mode == 2:
-                    diffusion_conds.append(
-                        torch.stack(temp_diffusion_conds).mean(0)
-                    )
+                    diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0))
         diffusion_conds = torch.stack(diffusion_conds, dim=1)

         with self.temporary_cuda(self.diffusion) as diffusion:
@@ -471,9 +445,7 @@ class TextToSpeech:
                 )
             )
         with torch.no_grad():
-            return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(
-                torch.tensor([0.0])
-            )
+            return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))

     def tts_with_preset(self, text, preset="fast", **kwargs):
         """
@@ -521,10 +493,7 @@ class TextToSpeech:
                 "diffusion_iterations": 50,
                 "sampler": "ddim",
             },
-            "fast_old": {
-                "num_autoregressive_samples": 96,
-                "diffusion_iterations": 80
-            },
+            "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
             "standard": {
                 "num_autoregressive_samples": 256,
                 "diffusion_iterations": 200,
@@ -618,9 +587,7 @@ class TextToSpeech:
         """
         deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)

-        text_tokens = (
-            torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
-        )
+        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
         text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
         assert (
             text_tokens.shape[-1] < 400
@@ -628,12 +595,7 @@ class TextToSpeech:

         auto_conds = None
         if voice_samples is not None:
-            (
-                auto_conditioning,
-                diffusion_conditioning,
-                auto_conds,
-                _,
-            ) = self.get_conditioning_latents(
+            (auto_conditioning, diffusion_conditioning, auto_conds, _,) = self.get_conditioning_latents(
                 voice_samples,
                 return_mels=True,
                 latent_averaging_mode=latent_averaging_mode,
@@ -650,10 +612,7 @@ class TextToSpeech:
         diffusion_conditioning = diffusion_conditioning.to(self.device)

         diffuser = load_discrete_vocoder_diffuser(
-            desired_diffusion_steps=diffusion_iterations,
-            cond_free=cond_free,
-            cond_free_k=cond_free_k,
-            sampler=sampler
+            desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler
         )

         # in the case of single_sample,
@@ -664,13 +623,13 @@ class TextToSpeech:
             samples = []
             num_batches = num_autoregressive_samples // self.autoregressive_batch_size
             stop_mel_token = self.autoregressive.stop_mel_token
-            calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
+            calm_token = (
+                83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
+            )
             self.autoregressive = self.autoregressive.to(self.device)
             if verbose:
                 print("Generating autoregressive samples..")
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive, torch.autocast(
+            with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
                 device_type="cuda", dtype=torch.float16, enabled=half
             ):
                 for b in tqdm(range(num_batches), disable=not verbose):
@@ -689,9 +648,7 @@ class TextToSpeech:
                     padding_needed = max_mel_tokens - codes.shape[1]
                     codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                     samples.append(codes)
-            self.autoregressive_batch_size = (
-                orig_batch_size  # in the case of single_sample
-            )
+            self.autoregressive_batch_size = orig_batch_size  # in the case of single_sample

             clip_results = []
             with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
@@ -729,9 +686,7 @@ class TextToSpeech:
                         if cvvp_amount == 1:
                             clip_results.append(cvvp)
                         else:
-                            clip_results.append(
-                                cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)
-                            )
+                            clip_results.append(cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount))
                     else:
                         clip_results.append(clvp_res)
                 clip_results = torch.cat(clip_results, dim=0)
@@ -744,19 +699,14 @@ class TextToSpeech:
             # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
             # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
             # results, but will increase memory usage.
-            with self.temporary_cuda(
-                self.autoregressive
-            ) as autoregressive:
+            with self.temporary_cuda(self.autoregressive) as autoregressive:
                 best_latents = autoregressive(
                     auto_conditioning.repeat(k, 1),
                     text_tokens.repeat(k, 1),
                     torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
                     best_results,
                     torch.tensor(
-                        [
-                            best_results.shape[-1]
-                            * self.autoregressive.mel_length_compression
-                        ],
+                        [best_results.shape[-1] * self.autoregressive.mel_length_compression],
                         device=text_tokens.device,
                     ),
                     return_latent=True,
@@ -778,9 +728,7 @@ class TextToSpeech:
                         ctokens += 1
                     else:
                         ctokens = 0
-                    if (
-                        ctokens > 8
-                    ):  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                    if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                         latents = latents[:, :k]
                         break
         with self.temporary_cuda(self.diffusion) as diffusion:
@@ -801,10 +749,7 @@ class TextToSpeech:
                     return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
                 return clip

-            wav_candidates = [
-                potentially_redact(wav_candidate, text)
-                for wav_candidate in wav_candidates
-            ]
+            wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]

             if len(wav_candidates) > 1:
                 res = wav_candidates
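To make the conditioning path above easier to follow, a hedged sketch of pre-computing latents from reference audio (not part of the commit). The 22.05 kHz / 24 kHz pairing mirrors the ls[0] / ls[1] access in get_conditioning_latents; the two-element return is an assumption.

# Sketch only -- the sample-pair format and the two-element return are assumptions.
import torchaudio

tts = TextToSpeech()
wav, sr = torchaudio.load("speaker_reference.wav")
voice_samples = [
    (
        torchaudio.functional.resample(wav, sr, 22050),  # used when original_tortoise=True
        torchaudio.functional.resample(wav, sr, 24000),  # used otherwise
    )
]
auto_latent, diffusion_latent = tts.get_conditioning_latents(
    voice_samples,
    latent_averaging_mode=1,  # 0: single 102400-sample window; 1/2: operate on chunks, per the code above
)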

View File

@@ -97,7 +97,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         self.mel_norm = mel_norm
         self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
         self.mel_basis = None
-        self.normalized=normalized
+        self.normalized = normalized
         if use_mel:
             self._build_mel_basis()
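For context, a construction-only sketch of this class (not part of the commit). Only window, win_length, mel_norm, use_mel and normalized appear in the hunk above; the remaining parameter names are assumptions.

# Sketch only -- n_fft and hop_length are assumed parameter names.
stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    window="hann_window",  # resolved via getattr(torch, window), as shown above
    normalized=True,       # stored on the instance, presumably forwarded to torch.stft
    use_mel=False,
)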