style fix

manmay-nakhashi 2023-04-22 17:49:47 +05:30
parent b892aa925a
commit 18c745ceef
22 changed files with 723 additions and 1605 deletions

View File

@ -1,9 +1,8 @@
import os
import functools
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
@ -11,6 +10,7 @@ from transformers import LogitsWarper
from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
def zero_module(module):
"""
Zero out the parameters of a module and return it.
@ -64,11 +64,11 @@ class QKVAttentionLegacy(nn.Module):
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
if rel_pos is not None:
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
bs * self.n_heads, weight.shape[-2], weight.shape[-1]
)
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
if mask is not None:
# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
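A quick check of the scaling trick above: multiplying q and k each by ch ** -0.25 before the einsum yields the same logits as dividing the product by sqrt(ch) afterwards, while keeping the intermediate values smaller, which is friendlier to float16. Minimal sketch (shapes are illustrative, not taken from the model):

import math
import torch

bs_heads, ch, length = 2, 64, 10
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)
post_scaled = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
print(torch.allclose(pre_scaled, post_scaled, atol=1e-5))  # True, up to float error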
@ -112,7 +112,13 @@ class AttentionBlock(nn.Module):
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
if relative_pos_embeddings:
self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
self.relative_pos_embeddings = RelativePositionBias(
scale=(channels // self.num_heads) ** 0.5,
causal=False,
heads=num_heads,
num_buckets=32,
max_distance=64,
)
else:
self.relative_pos_embeddings = None
@ -168,9 +174,7 @@ class Downsample(nn.Module):
stride = factor
if use_conv:
self.op = nn.Conv1d(
self.channels, self.out_channels, ksize, stride=stride, padding=pad
)
self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
else:
assert self.channels == self.out_channels
self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
@ -221,17 +225,13 @@ class ResBlock(nn.Module):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
),
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = nn.Conv1d(
channels, self.out_channels, kernel_size, padding=padding
)
self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
else:
self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
@ -249,7 +249,8 @@ class ResBlock(nn.Module):
class AudioMiniEncoder(nn.Module):
def __init__(self,
def __init__(
self,
spec_dim,
embedding_dim,
base_channels=128,
@ -259,27 +260,27 @@ class AudioMiniEncoder(nn.Module):
num_attn_heads=4,
dropout=0,
downsample_factor=2,
kernel_size=3):
kernel_size=3,
):
super().__init__()
self.init = nn.Sequential(
nn.Conv1d(spec_dim, base_channels, 3, padding=1)
)
self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
ch = base_channels
res = []
for l in range(depth):
for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.Conv1d(ch, embedding_dim, 1)
)
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
attn.append(
AttentionBlock(
embedding_dim,
num_attn_heads,
)
)
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim
@ -291,15 +292,24 @@ class AudioMiniEncoder(nn.Module):
return h[:, :, 0]
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../utils/assets/tortoise/mel_norms.pth')
DEFAULT_MEL_NORM_FILE = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/mel_norms.pth"
)
class TorchMelSpectrogram(nn.Module):
def __init__(self, filter_length=1024, hop_length=256,
win_length=1024, n_mel_channels=80,
mel_fmin=0, mel_fmax=8000,
sampling_rate=22050, normalize=False,
mel_norm_file=DEFAULT_MEL_NORM_FILE):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=80,
mel_fmin=0,
mel_fmax=8000,
sampling_rate=22050,
normalize=False,
mel_norm_file=DEFAULT_MEL_NORM_FILE,
):
super().__init__()
# These are the default tacotron values for the MEL spectrogram.
self.filter_length = filter_length
@ -309,7 +319,8 @@ class TorchMelSpectrogram(nn.Module):
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.sampling_rate = sampling_rate
self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length,
self.mel_stft = torchaudio.transforms.MelSpectrogram(
n_fft=self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
power=2,
@ -318,7 +329,8 @@ class TorchMelSpectrogram(nn.Module):
f_min=self.mel_fmin,
f_max=self.mel_fmax,
n_mels=self.n_mel_channels,
norm="slaney")
norm="slaney",
)
self.mel_norm_file = mel_norm_file
if self.mel_norm_file is not None:
self.mel_norms = torch.load(self.mel_norm_file)
@ -326,7 +338,9 @@ class TorchMelSpectrogram(nn.Module):
self.mel_norms = None
def forward(self, inp):
if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
if (
len(inp.shape) == 3
): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
inp = inp.squeeze(1)
assert len(inp.shape) == 2
self.mel_stft = self.mel_stft.to(inp.device)
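For orientation, a hedged usage sketch of the mel front-end configured above, using torchaudio's MelSpectrogram with the tacotron-style values shown in this hunk; the final log-clamp line is a common dynamic-range compression step and is illustrative rather than copied from the diff:

import torch
import torchaudio

mel_stft = torchaudio.transforms.MelSpectrogram(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
    norm="slaney",
)
wav = torch.randn(1, 22050)  # one second of placeholder mono audio
mel = mel_stft(wav)  # -> (1, 80, frames)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
print(log_mel.shape)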
@ -344,6 +358,7 @@ class CheckpointedLayer(nn.Module):
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
checkpoint for all other args.
"""
def __init__(self, wrap):
super().__init__()
self.wrap = wrap
@ -360,6 +375,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
to channels-last that XTransformer expects.
"""
def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
@ -374,10 +390,10 @@ class CheckpointedXTransformerEncoder(nn.Module):
def forward(self, x, **kwargs):
if self.needs_permute:
x = x.permute(0,2,1)
x = x.permute(0, 2, 1)
h = self.transformer(x, **kwargs)
if self.exit_permute:
h = h.permute(0,2,1)
h = h.permute(0, 2, 1)
return h
@ -392,9 +408,7 @@ class TypicalLogitsWarper(LogitsWarper):
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
@ -409,15 +423,11 @@ class TypicalLogitsWarper(LogitsWarper):
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
1, last_ind.view(-1, 1)
)
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
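The warper above is meant to drop into Hugging Face generation, which is how UnifiedVoice.inference_speech uses it later in this commit. A hedged usage sketch, with a small GPT-2 checkpoint standing in for the real autoregressive model:

from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from TTS.tts.layers.tortoise.arch_utils import TypicalLogitsWarper

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Typical sampling keeps tokens whose information content", return_tensors="pt")
gen = model.generate(
    **inputs,
    do_sample=True,
    max_length=64,
    logits_processor=LogitsProcessorList([TypicalLogitsWarper(mass=0.9)]),
)
print(tokenizer.decode(gen[0], skip_special_tokens=True))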

View File

@ -7,11 +7,10 @@ import numpy as np
import torch
import torchaudio
from scipy.io.wavfile import read
from TTS.utils.audio.torch_transforms import TorchSTFT
BUILTIN_VOICES_DIR = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices"
)
BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices")
def load_wav_to_torch(full_path):
@ -58,10 +57,7 @@ def read_audio_file(audiopath: str):
def load_required_audio(audiopath: str):
audio, lsr = read_audio_file(audiopath)
audios = [
torchaudio.functional.resample(audio, lsr, sampling_rate)
for sampling_rate in (22050, 24000)
]
audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)]
for audio in audios:
check_audio(audio, audiopath)
@ -83,9 +79,7 @@ TACOTRON_MEL_MIN = -11.512925148010254
def denormalize_tacotron_mel(norm_mel):
return ((norm_mel + 1) / 2) * (
TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
) + TACOTRON_MEL_MIN
return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
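A round-trip sanity check for the [-1, 1] mel scaling. TACOTRON_MEL_MIN appears in this hunk; the MAX value is restated from the surrounding codebase and the normalize step is simply the algebraic inverse of denormalize_tacotron_mel above, so treat both as assumptions rather than part of the diff:

import torch

TACOTRON_MEL_MIN = -11.512925148010254  # from the hunk above
TACOTRON_MEL_MAX = 2.3143386840820312  # assumed from the codebase, not shown in this hunk

mel = torch.empty(80, 100).uniform_(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX)
norm = 2.0 * (mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) - 1.0
back = ((norm + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
print(torch.allclose(mel, back, atol=1e-5))  # True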
def normalize_tacotron_mel(mel):
@ -118,11 +112,7 @@ def get_voices(extra_voice_dirs: List[str] = []):
for sub in subs:
subj = os.path.join(d, sub)
if os.path.isdir(subj):
voices[sub] = (
list(glob(f"{subj}/*.wav"))
+ list(glob(f"{subj}/*.mp3"))
+ list(glob(f"{subj}/*.pth"))
)
voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth"))
return voices
@ -148,9 +138,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
for voice in voices:
if voice == "random":
if len(voices) > 1:
print(
"Cannot combine a random voice with a non-random voice. Just using a random voice."
)
print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
return None, None
clip, latent = load_voice(voice, extra_voice_dirs)
if latent is None:
@ -171,15 +159,18 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
latents = (latents_0, latents_1)
return None, latents
def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
stft = TorchSTFT(n_fft=1024,
stft = TorchSTFT(
n_fft=1024,
hop_length=256,
win_length=1024,
use_mel=True,
n_mels=100,
sample_rate=24000,
mel_fmin=0,
mel_fmax=12000)
mel_fmax=12000,
)
stft = stft.to(device)
mel = stft(wav)
mel = dynamic_range_compression(mel)

View File

@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
def null_position_embeddings(range, dim):
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@ -98,9 +99,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
assert self.cached_mel_emb is not None
assert inputs_embeds is None # Not supported by this inference model.
assert labels is None # Training not supported by this inference model.
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Create embedding
mel_len = self.cached_mel_emb.shape[1]
@ -109,9 +108,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
text_emb = self.embeddings(text_inputs)
text_emb = text_emb + self.text_pos_embedding(text_emb)
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
mel_emb = self.cached_mel_emb.repeat_interleave(
text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
)
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
else: # this outcome only occurs once per loop in most cases
mel_emb = self.cached_mel_emb
emb = torch.cat([mel_emb, text_emb], dim=1)
@ -158,10 +155,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@ -210,9 +204,7 @@ class LearnedPositionEmbeddings(nn.Module):
return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
def build_hf_gpt_transformer(
layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
):
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
"""
GPT-2 implemented by the HuggingFace library.
"""
@ -230,9 +222,7 @@ def build_hf_gpt_transformer(
)
gpt = GPT2Model(gpt_config)
# Override the built in positional embeddings
del (
gpt.wpe
) # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
# Built-in token embeddings are unused.
del gpt.wte
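The del/assign pair above is the interesting part of this hunk: GPT-2's learned positional embedding is swapped for a callable that returns zeros, so position information comes from the model's own LearnedPositionEmbeddings instead. A self-contained sketch of the pattern with illustrative dimensions:

import functools

import torch
from transformers import GPT2Config, GPT2Model


def null_position_embeddings(range, dim):
    return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)


model_dim = 64
gpt = GPT2Model(GPT2Config(vocab_size=32, n_positions=128, n_embd=model_dim, n_layer=2, n_head=2))
del gpt.wpe  # drop the nn.Embedding that GPT2Model registered
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)

emb = torch.randn(1, 10, model_dim)  # pretend text/mel embeddings
out = gpt(inputs_embeds=emb, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 10, 64])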
@ -251,21 +241,15 @@ class MelEncoder(nn.Module):
self.channels = channels
self.encoder = nn.Sequential(
nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
nn.Sequential(
*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
),
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 16, channels // 2),
nn.ReLU(),
nn.Sequential(
*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
),
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
nn.GroupNorm(channels // 8, channels),
nn.ReLU(),
nn.Sequential(
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]
),
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
)
self.reduction = 4
@ -317,9 +301,7 @@ class UnifiedVoice(nn.Module):
super().__init__()
self.number_text_tokens = number_text_tokens
self.start_text_token = (
number_text_tokens * types if start_text_token is None else start_text_token
)
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
self.stop_text_token = 0
self.number_mel_codes = number_mel_codes
self.start_mel_token = start_mel_token
@ -331,12 +313,8 @@ class UnifiedVoice(nn.Module):
self.model_dim = model_dim
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = ConditioningEncoder(
80, model_dim, num_attn_heads=heads
)
self.text_embedding = nn.Embedding(
self.number_text_tokens * types + 1, model_dim
)
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
if use_mel_codes_as_input:
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
else:
@ -356,12 +334,8 @@ class UnifiedVoice(nn.Module):
checkpointing,
)
if train_solo_embeddings:
self.mel_solo_embedding = nn.Parameter(
torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
)
self.text_solo_embedding = nn.Parameter(
torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
)
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
else:
self.mel_solo_embedding = 0
self.text_solo_embedding = 0
@ -414,9 +388,7 @@ class UnifiedVoice(nn.Module):
preformatting to create a working TTS model.
"""
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
mel_lengths = torch.div(
wav_lengths, self.mel_length_compression, rounding_mode="trunc"
)
mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc")
for b in range(len(mel_lengths)):
actual_end = (
mel_lengths[b] + 1
@ -436,31 +408,22 @@ class UnifiedVoice(nn.Module):
return_latent=False,
):
if second_inputs is not None:
emb = torch.cat(
[speech_conditioning_inputs, first_inputs, second_inputs], dim=1
)
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
else:
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
gpt_out = self.gpt(
inputs_embeds=emb, return_dict=True, output_attentions=get_attns
)
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
if get_attns:
return gpt_out.attentions
enc = gpt_out.last_hidden_state[
:, 1:
] # The first logit is tied to the speech_conditioning_input
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
enc = self.final_norm(enc)
if return_latent:
return (
enc[
:,
speech_conditioning_inputs.shape[
1
] : speech_conditioning_inputs.shape[1]
+ first_inputs.shape[1],
speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1],
],
enc[:, -second_inputs.shape[1] :],
)
@ -539,9 +502,7 @@ class UnifiedVoice(nn.Module):
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
text_inputs
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
mel_codes, self.start_mel_token, self.stop_mel_token
)
@ -596,15 +557,13 @@ class UnifiedVoice(nn.Module):
max_generate_length=None,
typical_sampling=False,
typical_mass=0.9,
**hf_generate_kwargs
**hf_generate_kwargs,
):
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
text_inputs, self.start_text_token, self.stop_text_token
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
text_inputs
)
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
conds = speech_conditioning_latent.unsqueeze(1)
emb = torch.cat([conds, text_emb], dim=1)
@ -628,20 +587,14 @@ class UnifiedVoice(nn.Module):
num_return_sequences % input_tokens.shape[0] == 0
), "The number of return sequences must be divisible by the number of input sequences"
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
input_tokens = input_tokens.repeat(
num_return_sequences // input_tokens.shape[0], 1
)
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
logits_processor = (
LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
if typical_sampling
else LogitsProcessorList()
LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
) # TODO disable this
max_length = (
trunc_index + self.max_mel_tokens - 1
if max_generate_length is None
else trunc_index + max_generate_length
trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
)
gen = self.inference_model.generate(
inputs,
@ -651,7 +604,7 @@ class UnifiedVoice(nn.Module):
max_length=max_length,
logits_processor=logits_processor,
num_return_sequences=num_return_sequences,
**hf_generate_kwargs
**hf_generate_kwargs,
)
return gen[:, trunc_index:]

View File

@ -1,13 +1,7 @@
import torch
import torch.nn as nn
from TTS.tts.layers.tortoise.arch_utils import (
AttentionBlock,
Downsample,
Upsample,
normalization,
zero_module,
)
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module
class ResBlock(nn.Module):
@ -54,19 +48,13 @@ class ResBlock(nn.Module):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv1d(
self.out_channels, self.out_channels, kernel_size, padding=padding
)
),
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = nn.Conv1d(
dims, channels, self.out_channels, kernel_size, padding=padding
)
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding)
else:
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
@ -104,24 +92,14 @@ class AudioMiniEncoder(nn.Module):
self.layers = depth
for l in range(depth):
for r in range(resnet_blocks):
res.append(
ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
)
res.append(
Downsample(
ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
)
)
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
ch *= 2
self.res = nn.Sequential(*res)
self.final = nn.Sequential(
normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
)
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
attn = []
for a in range(attn_blocks):
attn.append(
AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
)
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
self.attn = nn.Sequential(*attn)
self.dim = embedding_dim

View File

@ -12,9 +12,10 @@ def exists(val):
return val is not None
def masked_mean(t, mask, dim = 1):
t = t.masked_fill(~mask[:, :, None], 0.)
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
def masked_mean(t, mask, dim=1):
t = t.masked_fill(~mask[:, :, None], 0.0)
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
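A tiny usage sketch for masked_mean: pool token features while ignoring padded positions (the helper is restated verbatim from the lines above):

import torch


def masked_mean(t, mask, dim=1):
    t = t.masked_fill(~mask[:, :, None], 0.0)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


t = torch.arange(12.0).reshape(1, 4, 3)  # (batch, seq, dim)
mask = torch.tensor([[True, True, False, False]])  # last two positions are padding
print(masked_mean(t, mask))  # tensor([[1.5000, 2.5000, 3.5000]])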
class CLVP(nn.Module):
"""
@ -59,13 +60,14 @@ class CLVP(nn.Module):
dim=dim_text,
depth=text_enc_depth,
heads=text_heads,
ff_dropout=.1,
ff_dropout=0.1,
ff_mult=2,
attn_dropout=.1,
attn_dropout=0.1,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
))
),
)
self.speech_transformer = CheckpointedXTransformerEncoder(
needs_permute=False,
exit_permute=False,
@ -74,20 +76,23 @@ class CLVP(nn.Module):
dim=dim_speech,
depth=speech_enc_depth,
heads=speech_heads,
ff_dropout=.1,
ff_dropout=0.1,
ff_mult=2,
attn_dropout=.1,
attn_dropout=0.1,
use_rmsnorm=True,
ff_glu=True,
rotary_pos_emb=True,
))
),
)
else:
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
heads=text_heads)
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
depth=speech_enc_depth, heads=speech_heads)
self.text_transformer = Transformer(
causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads
)
self.speech_transformer = Transformer(
causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads
)
self.temperature = nn.Parameter(torch.tensor(1.))
self.temperature = nn.Parameter(torch.tensor(1.0))
self.text_mask_percentage = text_mask_percentage
self.voice_mask_percentage = voice_mask_percentage
self.wav_token_compression = wav_token_compression
@ -96,12 +101,7 @@ class CLVP(nn.Module):
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
def forward(
self,
text,
speech_tokens,
return_loss=False
):
def forward(self, text, speech_tokens, return_loss=False):
b, device = text.shape[0], text.device
if self.training:
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
@ -131,25 +131,29 @@ class CLVP(nn.Module):
temp = self.temperature.exp()
if not return_loss:
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
return sim
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
labels = torch.arange(b, device=device)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
return loss
if __name__ == '__main__':
clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
clip(torch.randint(0,256,(2,120)),
torch.tensor([50,100]),
torch.randint(0,8192,(2,250)),
torch.tensor([101,102]),
return_loss=True)
nonloss = clip(torch.randint(0,256,(2,120)),
torch.tensor([50,100]),
torch.randint(0,8192,(2,250)),
torch.tensor([101,102]),
return_loss=False)
if __name__ == "__main__":
clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
clip(
torch.randint(0, 256, (2, 120)),
torch.tensor([50, 100]),
torch.randint(0, 8192, (2, 250)),
torch.tensor([101, 102]),
return_loss=True,
)
nonloss = clip(
torch.randint(0, 256, (2, 120)),
torch.tensor([50, 100]),
torch.randint(0, 8192, (2, 250)),
torch.tensor([101, 102]),
return_loss=False,
)
print(nonloss.shape)
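For reference, the CLIP-style objective CLVP.forward builds above, isolated into a stand-alone sketch: a similarity matrix over in-batch text/speech latents, a learned temperature, and symmetric cross-entropy. The explicit L2-normalization here mirrors the CVVP hunk further down; CLVP's own normalization lines fall outside this hunk, so that step is an assumption:

import torch
import torch.nn.functional as F
from torch import einsum

b, d = 4, 32
text_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
speech_latents = F.normalize(torch.randn(b, d), p=2, dim=-1)
temp = torch.tensor(1.0).exp()  # stand-in for the learned temperature parameter

sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
labels = torch.arange(b)
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
print(loss)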

View File

@ -17,16 +17,7 @@ def masked_mean(t, mask):
class CollapsingTransformer(nn.Module):
def __init__(
self,
model_dim,
output_dims,
heads,
dropout,
depth,
mask_percentage=0,
**encoder_kwargs
):
def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
super().__init__()
self.transformer = ContinuousTransformerWrapper(
max_seq_len=-1,
@ -105,9 +96,7 @@ class CVVP(nn.Module):
self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
if mel_codes is None:
self.speech_emb = nn.Conv1d(
mel_channels, model_dim, kernel_size=5, padding=2
)
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
else:
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
self.speech_transformer = CollapsingTransformer(
@ -135,9 +124,7 @@ class CVVP(nn.Module):
enc_speech = self.speech_transformer(speech_emb)
speech_latents = self.to_speech_latent(enc_speech)
cond_latents, speech_latents = map(
lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
)
cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
temp = self.temperature.exp()
if not return_loss:

View File

@ -13,8 +13,8 @@ import math
import numpy as np
import torch
import torch as th
from tqdm import tqdm
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
from tqdm import tqdm
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
@ -38,18 +38,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
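For reference, the closed form this one-liner computes, with \texttt{logvar}_i = \log\sigma_i^2:

\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
  = \tfrac{1}{2}\Big(-1 + \log\sigma_2^2 - \log\sigma_1^2 + \tfrac{\sigma_1^2}{\sigma_2^2} + \tfrac{(\mu_1-\mu_2)^2}{\sigma_2^2}\Big)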
def approx_standard_normal_cdf(x):
@ -112,9 +103,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
@ -149,9 +138,9 @@ class ModelMeanType(enum.Enum):
Which type of output the model predicts.
"""
PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
START_X = 'start_x' # the model predicts x_0
EPSILON = 'epsilon' # the model predicts epsilon
PREVIOUS_X = "previous_x" # the model predicts x_{t-1}
START_X = "start_x" # the model predicts x_0
EPSILON = "epsilon" # the model predicts epsilon
class ModelVarType(enum.Enum):
@ -162,17 +151,17 @@ class ModelVarType(enum.Enum):
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
LEARNED = 'learned'
FIXED_SMALL = 'fixed_small'
FIXED_LARGE = 'fixed_large'
LEARNED_RANGE = 'learned_range'
LEARNED = "learned"
FIXED_SMALL = "fixed_small"
FIXED_LARGE = "fixed_large"
LEARNED_RANGE = "learned_range"
class LossType(enum.Enum):
MSE = 'mse' # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
KL = 'kl' # use the variational lower-bound
RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
MSE = "mse" # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances)
KL = "kl" # use the variational lower-bound
RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB
def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL
@ -239,22 +228,12 @@ class GaussianDiffusion:
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = (
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
# log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log(
np.append(self.posterior_variance[1], self.posterior_variance[1:])
)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
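These coefficients are the standard DDPM posterior q(x_{t-1} | x_t, x_0), written out here for reference:

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\big),
\qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t \quad (\texttt{posterior\_variance}),

\tilde\mu_t(x_t, x_0)
  = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t
\quad (\texttt{posterior\_mean\_coef1},\ \texttt{posterior\_mean\_coef2}).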
def q_mean_variance(self, x_start, t):
"""
@ -264,13 +243,9 @@ class GaussianDiffusion:
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None):
@ -289,8 +264,7 @@ class GaussianDiffusion:
assert noise.shape == x_start.shape
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
@ -306,9 +280,7 @@ class GaussianDiffusion:
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
@ -317,9 +289,7 @@ class GaussianDiffusion:
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
"""
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0.
@ -358,9 +328,7 @@ class GaussianDiffusion:
model_log_variance = model_var_values
model_variance = th.exp(model_log_variance)
else:
min_log = _extract_into_tensor(
self.posterior_log_variance_clipped, t, x.shape
)
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
@ -398,26 +366,18 @@ class GaussianDiffusion:
return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart(
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output)
else:
pred_xstart = process_xstart(
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert (
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
return {
"mean": model_mean,
"variance": model_variance,
@ -436,16 +396,12 @@ class GaussianDiffusion:
assert x_t.shape == xprev.shape
return ( # (xprev - coef2*x_t) / coef1
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
- _extract_into_tensor(
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
)
* x_t
- _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t
)
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- pred_xstart
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _scale_timesteps(self, t):
@ -463,9 +419,7 @@ class GaussianDiffusion:
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
"""
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
new_mean = (
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@ -481,16 +435,13 @@ class GaussianDiffusion:
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
x, self._scale_timesteps(t), **model_kwargs
)
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs)
out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance(
x_start=out["pred_xstart"], x_t=x, t=t
)
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
return out
def k_diffusion_sample_loop(
self,
k_sampler,
@ -512,9 +463,7 @@ class GaussianDiffusion:
def model_split(*args, **kwargs):
model_output = model(*args, **kwargs)
model_epsilon, model_var = th.split(
model_output, model_output.shape[1] // 2, dim=1
)
model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1)
return model_epsilon, model_var
#
@ -523,9 +472,7 @@ class GaussianDiffusion:
print(th.tensor(self.betas))
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas))
"""
noise_schedule = NoiseScheduleVP(
schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4
)
noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4)
def model_fn_prewrap(x, t, *args, **kwargs):
"""
@ -584,11 +531,10 @@ class GaussianDiffusion:
if self.conditioning_free is not True:
raise RuntimeError("cond_free must be true")
with tqdm(total=self.num_timesteps) as pbar:
return self.k_diffusion_sample_loop(
K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs
)
return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
else:
raise RuntimeError("sampler not impl")
def p_sample(
self,
model,
@ -625,13 +571,9 @@ class GaussianDiffusion:
model_kwargs=model_kwargs,
)
noise = th.randn_like(x)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
if cond_fn is not None:
out["mean"] = self.condition_mean(
cond_fn, out, x, t, model_kwargs=model_kwargs
)
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -758,20 +700,11 @@ class GaussianDiffusion:
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = (
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
# Equation 12.
noise = th.randn_like(x)
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -800,16 +733,12 @@ class GaussianDiffusion:
# Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction.
eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
- out["pred_xstart"]
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed
mean_pred = (
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
@ -897,9 +826,7 @@ class GaussianDiffusion:
yield out
img = out["sample"]
def _vb_terms_bpd(
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
"""
Get a term for the variational lower-bound.
@ -910,15 +837,9 @@ class GaussianDiffusion:
- 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions.
"""
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood(
@ -969,7 +890,7 @@ class GaussianDiffusion:
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
if isinstance(model_outputs, tuple):
model_output = model_outputs[0]
terms['extra_outputs'] = model_outputs[1:]
terms["extra_outputs"] = model_outputs[1:]
else:
model_output = model_outputs
@ -996,9 +917,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0]
target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X:
target = x_start
@ -1020,7 +939,9 @@ class GaussianDiffusion:
return terms
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
def autoregressive_training_losses(
self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None
):
"""
Compute training losses for a single timestep.
@ -1068,9 +989,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance(
x_start=x_start, x_t=x_t, t=t
)[0]
target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X:
target = x_start
@ -1105,9 +1024,7 @@ class GaussianDiffusion:
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
@ -1183,9 +1100,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
num_diffusion_timesteps,
@ -1219,19 +1134,13 @@ class SpacedDiffusion(GaussianDiffusion):
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def autoregressive_training_losses(
self, model, *args, **kwargs
): # pylint: disable=signature-differs
def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
@ -1244,9 +1153,7 @@ class SpacedDiffusion(GaussianDiffusion):
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
return model
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
return mod(
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
)
return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
@ -1281,9 +1188,7 @@ def space_timesteps(num_timesteps, section_counts):
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
@ -1292,9 +1197,7 @@ def space_timesteps(num_timesteps, section_counts):
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}"
)
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
@ -1315,6 +1218,7 @@ class _WrappedModel:
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
@ -1323,6 +1227,7 @@ class _WrappedModel:
model_output = self.model(x, new_ts, **kwargs)
return model_output
class _WrappedAutoregressiveModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
@ -1337,6 +1242,7 @@ class _WrappedAutoregressiveModel:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, x0, new_ts, **kwargs)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.

View File

@ -29,11 +29,9 @@ def timestep_embedding(timesteps, dim, max_period=10000):
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=timesteps.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
@ -98,17 +96,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
nn.Conv1d(
self.out_channels, self.out_channels, kernel_size, padding=padding
),
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Conv1d(
channels, self.out_channels, eff_kernel, padding=eff_padding
)
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
def forward(self, x, emb):
h = self.in_layers(x)
@ -137,9 +131,7 @@ class DiffusionLayer(TimestepBlock):
dims=1,
use_scale_shift_norm=True,
)
self.attn = AttentionBlock(
model_channels, num_heads, relative_pos_embeddings=True
)
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
def forward(self, x, time_emb):
y = self.resblk(x, time_emb)
@ -239,16 +231,11 @@ class DiffusionTts(nn.Module):
DiffusionLayer(model_channels, dropout, num_heads),
)
self.integrating_conv = nn.Conv1d(
model_channels * 2, model_channels, kernel_size=1
)
self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList(
[
DiffusionLayer(model_channels, dropout, num_heads)
for _ in range(num_layers)
]
[DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
+ [
ResBlock(
model_channels,
@ -275,9 +262,7 @@ class DiffusionTts(nn.Module):
+ list(self.code_converter.parameters())
+ list(self.latent_conditioner.parameters())
+ list(self.latent_conditioner.parameters()),
"timestep_integrator": list(
self.conditioning_timestep_integrator.parameters()
)
"timestep_integrator": list(self.conditioning_timestep_integrator.parameters())
+ list(self.integrating_conv.parameters()),
"time_embed": list(self.time_embed.parameters()),
}
@ -285,9 +270,7 @@ class DiffusionTts(nn.Module):
def get_conditioning(self, conditioning_input):
speech_conditioning_input = (
conditioning_input.unsqueeze(1)
if len(conditioning_input.shape) == 3
else conditioning_input
conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
)
conds = []
for j in range(speech_conditioning_input.shape[1]):
@ -313,29 +296,20 @@ class DiffusionTts(nn.Module):
else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb)
code_emb = self.code_norm(code_emb) * (
1 + cond_scale.unsqueeze(-1)
) + cond_shift.unsqueeze(-1)
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros(
(code_emb.shape[0], 1, 1), device=code_emb.device
)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = (
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device)
< self.unconditioned_percentage
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
)
code_emb = torch.where(
unconditioned_batches,
self.unconditioned_embedding.repeat(
aligned_conditioning.shape[0], 1, 1
),
self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
code_emb,
)
expanded_code_emb = F.interpolate(
code_emb, size=expected_seq_len, mode="nearest"
)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest")
if not return_code_pred:
return expanded_code_emb
@ -376,10 +350,7 @@ class DiffusionTts(nn.Module):
unused_params = []
if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend(
list(self.code_converter.parameters())
+ list(self.code_embedding.parameters())
)
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
unused_params.extend(list(self.latent_conditioner.parameters()))
else:
if precomputed_aligned_embeddings is not None:
@ -390,8 +361,7 @@ class DiffusionTts(nn.Module):
)
if is_latent(aligned_conditioning):
unused_params.extend(
list(self.code_converter.parameters())
+ list(self.code_embedding.parameters())
list(self.code_converter.parameters()) + list(self.code_embedding.parameters())
)
else:
unused_params.extend(list(self.latent_conditioner.parameters()))

View File

@ -1,6 +1,8 @@
import math
import torch
class NoiseScheduleVP:
def __init__(
self,
@ -107,11 +109,7 @@ class NoiseScheduleVP:
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.0
self.t_array = (
torch.linspace(0.0, 1.0, self.total_N + 1)[1:]
.reshape((1, -1))
.to(dtype=dtype)
)
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
self.log_alpha_array = log_alphas.reshape(
(
1,
@ -131,9 +129,7 @@ class NoiseScheduleVP:
/ math.pi
- self.cosine_s
)
self.cosine_log_alpha_0 = math.log(
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
)
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
self.schedule = schedule
if schedule == "cosine":
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
@ -157,11 +153,7 @@ class NoiseScheduleVP:
elif self.schedule == "cosine":
def log_alpha_fn(s):
return torch.log(
torch.cos(
(s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
)
)
return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
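For orientation, the quantities these schedule helpers expose, in the usual DPM-Solver notation (marginal_log_mean_coeff returns \log\alpha_t):

x_t = \alpha_t x_0 + \sigma_t \epsilon, \qquad
\sigma_t = \sqrt{1-\alpha_t^2}, \qquad
\lambda_t = \log\frac{\alpha_t}{\sigma_t},

and for the "linear" schedule \log\alpha_t = -\tfrac{\beta_1-\beta_0}{4}\,t^2 - \tfrac{\beta_0}{2}\,t, which is the relation that inverse_lambda inverts in the next hunk.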
@ -191,17 +183,11 @@ class NoiseScheduleVP:
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == "linear":
tmp = (
2.0
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp(
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
)
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
t = interpolate_fn(
log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
@ -345,14 +331,10 @@ def model_wrapper(
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return (x - alpha_t * output) / sigma_t
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return alpha_t * output + sigma_t * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
@ -482,9 +464,7 @@ class DPM_Solver:
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(
torch.maximum(
s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
),
torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)),
dims,
)
x0 = torch.clamp(x0, -s, s) / s
@ -501,9 +481,7 @@ class DPM_Solver:
Return the data prediction model (with corrector).
"""
noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
t
), self.noise_schedule.marginal_std(t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t)
@ -536,30 +514,20 @@ class DPM_Solver:
if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
).to(device)
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic":
t_order = 2
t = (
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
.pow(t_order)
.to(device)
)
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
skip_type
)
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
)
def get_orders_and_timesteps_for_singlestep_solver(
self, steps, order, skip_type, t_T, t_0, device
):
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
@ -664,9 +632,7 @@ class DPM_Solver:
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
s
), ns.marginal_log_mean_coeff(t)
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
@ -716,11 +682,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
@ -766,10 +728,7 @@ class DPM_Solver:
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
torch.exp(log_alpha_s1 - log_alpha_s) * x
- (sigma_s1 * phi_11) * model_s
)
x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver":
x_t = (
@ -820,11 +779,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
@ -901,9 +856,7 @@ class DPM_Solver:
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (
sigma_s1 * phi_11
) * model_s
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
@ -934,9 +887,7 @@ class DPM_Solver:
else:
return x_t
def multistep_dpm_solver_second_update(
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
@ -951,11 +902,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
@ -964,9 +911,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
@ -977,11 +922,7 @@ class DPM_Solver:
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
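# phi_1 = exp(-h) - 1; in dpmsolver++ mode the update below is taken in data-prediction (x0) space.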
if solver_type == "dpmsolver":
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
- 0.5 * (alpha_t * phi_1) * D1_0
)
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
elif solver_type == "taylor":
x_t = (
(sigma_t / sigma_prev_0) * x
@ -1004,9 +945,7 @@ class DPM_Solver:
)
return x_t
def multistep_dpm_solver_third_update(
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
@ -1029,9 +968,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
@ -1093,9 +1030,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(
x, s, t, return_intermediate=return_intermediate
)
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(
x,
@ -1118,9 +1053,7 @@ class DPM_Solver:
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update(
self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"
):
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
@ -1136,17 +1069,11 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
)
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
@ -1198,9 +1125,7 @@ class DPM_Solver:
return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
@ -1211,16 +1136,10 @@ class DPM_Solver:
)
def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_third_update(
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
else:
raise ValueError(
"For adaptive step size solver, order must be 2 or 3, got {}".format(
order
)
)
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
@ -1231,9 +1150,7 @@ class DPM_Solver:
)
def norm_fn(v):
return torch.sqrt(
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
)
return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
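# norm_fn is a per-sample RMS; E below is the largest scaled gap between the lower- and higher-order updates.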
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0):
@ -1259,9 +1176,7 @@ class DPM_Solver:
Returns:
xt with shape `(t_size, batch_size, *shape)`.
"""
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
t
), self.noise_schedule.marginal_std(t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape))
@ -1468,9 +1383,7 @@ class DPM_Solver:
)
elif method == "multistep":
assert steps >= order
timesteps = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
)
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps
# Init the initial values.
step = 0
@ -1527,10 +1440,7 @@ class DPM_Solver:
model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps,
order=order,
skip_type=skip_type,
@ -1543,9 +1453,7 @@ class DPM_Solver:
orders = [
order,
] * K
timesteps_outer = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
)
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps(
@ -1559,9 +1467,7 @@ class DPM_Solver:
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(
x, s, t, order, solver_type=solver_type, r1=r1, r2=r2
)
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
@ -1613,9 +1519,7 @@ def interpolate_fn(x, xp, yp):
cand_start_idx,
),
)
end_idx = torch.where(
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
@ -1628,12 +1532,8 @@ def interpolate_fn(x, xp, yp):
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(
y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
).squeeze(2)
end_y = torch.gather(
y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
).squeeze(2)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
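# Minimal usage sketch (shapes assumed: x is [N, C], xp and yp are [C, K]):
#   xp = torch.tensor([[0.0, 1.0, 2.0]]); yp = torch.tensor([[0.0, 10.0, 20.0]])
#   interpolate_fn(torch.tensor([[0.5]]), xp, yp)  # -> tensor([[5.0]]) by piecewise-linear interpolation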

View File

@ -40,8 +40,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels):
super().__init__()
self.layers = nn.Sequential(
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)],
nn.Linear(channels, channels)
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
)
self.channels = channels

View File

@ -95,9 +95,7 @@ def _expand_number(m):
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(
num, andword="", zero="oh", group=2
).replace(", ", " ")
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
@ -165,9 +163,7 @@ def lev_distance(s1, s2):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(
1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
)
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
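# e.g. lev_distance("kitten", "sitting") == 3 (substitute k->s, substitute e->i, insert g).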

View File

@ -36,12 +36,8 @@ def route_args(router, args, depth):
for key in matched_keys:
val = args[key]
for depth, ((f_args, g_args), routes) in enumerate(
zip(routed_args, router[key])
):
new_f_args, new_g_args = map(
lambda route: ({key: val} if route else {}), routes
)
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args
@ -217,12 +213,8 @@ class Transformer(nn.Module):
layers.append(
nn.ModuleList(
[
LayerScale(
dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)
),
LayerScale(
dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
),
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
]
)
)

View File

@ -1,5 +1,7 @@
import os
try: import gdown
try:
import gdown
except ImportError:
raise ImportError(
"Sorry, gdown is required in order to download the new BigVGAN vocoder.\n"
@ -11,9 +13,7 @@ import progressbar
D_STEM = "https://drive.google.com/uc?id="
DEFAULT_MODELS_DIR = os.path.join(
os.path.expanduser("~"), ".cache", "tortoise", "models"
)
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
@ -30,6 +30,8 @@ MODELS = {
}
pbar = None
def download_models(specific_models=None):
"""
Call to download all the models that Tortoise uses.
@ -62,6 +64,7 @@ def download_models(specific_models=None):
request.urlretrieve(url, model_path, show_progress)
print("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR):
"""
Get path to given model, download it if it doesn't exist.

View File

@ -1,12 +1,12 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from enum import Enum
from typing import Optional, Callable
from dataclasses import dataclass
MAX_WAV_VALUE = 32768.0
@ -40,18 +40,12 @@ class KernelPredictor(torch.nn.Module):
self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers
kpnet_kernel_channels = (
conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
) # l_w
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential(
nn.utils.weight_norm(
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params
),
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
self.residual_convs = nn.ModuleList()
@ -69,9 +63,7 @@ class KernelPredictor(torch.nn.Module):
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
nn.utils.weight_norm(
nn.Conv1d(
kpnet_hidden_channels,
@ -81,9 +73,7 @@ class KernelPredictor(torch.nn.Module):
bias=True,
)
),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params
),
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
)
)
self.kernel_conv = nn.utils.weight_norm(
@ -252,17 +242,11 @@ class LVCBlock(torch.nn.Module):
"""
batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == (
kernel_length * hop_size
), "length of (x, kernel) is not matched"
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2)
x = F.pad(
x, (padding, padding), "constant", 0
) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(
2, hop_size + 2 * padding, hop_size
) # (batch, in_channels, kernel_length, hop_size + 2*padding)
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0)
@ -270,12 +254,8 @@ class LVCBlock(torch.nn.Module):
3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size]
x = x.transpose(
3, 4
) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(
4, kernel_size, 1
) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d)
@ -334,15 +314,11 @@ class UnivNetGenerator(nn.Module):
)
)
self.conv_pre = nn.utils.weight_norm(
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
)
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm(
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
),
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Tanh(),
)
@ -399,12 +375,16 @@ class VocType:
constructor: Callable[[], nn.Module]
model_path: str
subkey: Optional[str] = None
def optionally_index(self, model_dict):
if self.subkey is not None:
return model_dict[self.subkey]
return model_dict
class VocConf(Enum):
Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g')
Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g")
if __name__ == "__main__":
model = UnivNetGenerator()

View File

@ -12,9 +12,7 @@ def max_alignment(s1, s2, skip_character="~", record=None):
"""
if record is None:
record = {}
assert (
skip_character not in s1
), f"Found the skip character {skip_character} in the provided string, {s1}"
assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
if len(s1) == 0:
return ""
if len(s2) == 0:
@ -49,15 +47,9 @@ class Wav2VecAlignment:
"""
def __init__(self, device="cuda"):
self.model = Wav2Vec2ForCTC.from_pretrained(
"jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
).cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-large-960h"
)
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
"jbetker/tacotron-symbols"
)
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols")
self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000):
@ -117,9 +109,7 @@ class Wav2VecAlignment:
)
# Now fix up alignments. Anything with -1 should be interpolated.
alignments.append(
orig_len
) # This'll get removed but makes the algorithm below more readable.
alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
for i in range(len(alignments)):
if alignments[i] == -1:
for j in range(i + 1, len(alignments)):
@ -128,9 +118,7 @@ class Wav2VecAlignment:
break
for j in range(i, next_found_token):
gap = alignments[next_found_token] - alignments[i - 1]
alignments[j] = (j - i + 1) * gap // (
next_found_token - i + 1
) + alignments[i - 1]
alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
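# Example: alignments [5, -1, -1, 20] become [5, 10, 15, 20]; the gap of 15 is split evenly across the missing entries.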
return alignments[:-1]
@ -140,9 +128,7 @@ class Wav2VecAlignment:
splitted = expected_text.split("[")
fully_split = [splitted[0]]
for spl in splitted[1:]:
assert (
"]" in spl
), 'Every "[" character must be paired with a "]" with no nesting.'
assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
fully_split.extend(spl.split("]"))
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.

View File

@ -6,24 +6,25 @@ from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn, einsum
from torch import einsum, nn
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
LayerIntermediates = namedtuple('Intermediates', [
'hiddens',
'attn_intermediates',
'past_key_values',
])
LayerIntermediates = namedtuple(
"Intermediates",
[
"hiddens",
"attn_intermediates",
"past_key_values",
],
)
# helpers
def exists(val):
return val is not None
@ -38,7 +39,7 @@ def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
class always():
class always:
def __init__(self, val):
self.val = val
@ -46,7 +47,7 @@ class always():
return self.val
class not_equals():
class not_equals:
def __init__(self, val):
self.val = val
@ -54,7 +55,7 @@ class not_equals():
return x != self.val
class equals():
class equals:
def __init__(self, val):
self.val = val
@ -72,14 +73,16 @@ def l2norm(t):
# init helpers
def init_zero_(layer):
nn.init.constant_(layer.weight, 0.)
nn.init.constant_(layer.weight, 0.0)
if exists(layer.bias):
nn.init.constant_(layer.bias, 0.)
nn.init.constant_(layer.bias, 0.0)
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@ -104,12 +107,13 @@ def group_by_key_prefix(prefix, d):
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# activations
class ReluSquared(nn.Module):
def forward(self, x):
return F.relu(x) ** 2
@ -117,30 +121,31 @@ class ReluSquared(nn.Module):
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
pos_emb = self.emb(n)
pos_emb = rearrange(pos_emb, 'n d -> () n d')
pos_emb = rearrange(pos_emb, "n d -> () n d")
return pos_emb * self.scale
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return rearrange(emb, 'n d -> () n d')
return rearrange(emb, "n d -> () n d")
class RelativePositionBias(nn.Module):
@ -166,9 +171,10 @@ class RelativePositionBias(nn.Module):
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = (
max_exact
+ (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
@ -179,10 +185,11 @@ class RelativePositionBias(nn.Module):
q_pos = torch.arange(i, dtype=torch.long, device=device)
k_pos = torch.arange(j, dtype=torch.long, device=device)
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets,
max_distance=self.max_distance)
rp_bucket = self._relative_position_bucket(
rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance
)
values = self.relative_attention_bias(rp_bucket)
bias = rearrange(values, 'i j h -> () h i j')
bias = rearrange(values, "i j h -> () h i j")
return qk_dots + (bias * self.scale)
@ -191,23 +198,25 @@ class AlibiPositionalBias(nn.Module):
super().__init__()
self.heads = heads
slopes = torch.Tensor(self._get_slopes(heads))
slopes = rearrange(slopes, 'h -> () h () ()')
self.register_buffer('slopes', slopes, persistent=False)
self.register_buffer('bias', None, persistent=False)
slopes = rearrange(slopes, "h -> () h () ()")
self.register_buffer("slopes", slopes, persistent=False)
self.register_buffer("bias", None, persistent=False)
@staticmethod
def _get_slopes(heads):
def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
return [start * ratio**i for i in range(n)]
if math.log2(heads).is_integer():
return get_slopes_power_of_2(heads)
closest_power_of_2 = 2 ** math.floor(math.log2(heads))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
:heads - closest_power_of_2]
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2]
)
def forward(self, qk_dots):
h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
@ -216,13 +225,13 @@ class AlibiPositionalBias(nn.Module):
return qk_dots + self.bias[..., :j]
bias = torch.arange(j, device=device)
bias = rearrange(bias, 'j -> () () () j')
bias = rearrange(bias, "j -> () () () j")
bias = bias * self.slopes
num_heads_unalibied = h - bias.shape[1]
bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
self.register_buffer('bias', bias, persistent=False)
self.register_buffer("bias", bias, persistent=False)
return qk_dots + self.bias
@ -247,8 +256,8 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
else:
i_arange = torch.arange(i, device=device)
j_arange = torch.arange(j, device=device)
bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1')
self.register_buffer('bias', bias, persistent=False)
bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1")
self.register_buffer("bias", bias, persistent=False)
if self.bidirectional:
past_slopes = get_slopes(self.learned_logslopes)
@ -264,18 +273,18 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias):
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, max_seq_len, device):
t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return rearrange(emb, 'n d -> () () n d')
return rearrange(emb, "n d -> () () n d")
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j=2)
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
@ -288,6 +297,7 @@ def apply_rotary_pos_emb(t, freqs):
# norms
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
@ -323,7 +333,7 @@ class Rezero(nn.Module):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
@ -335,7 +345,7 @@ class ScaleNorm(nn.Module):
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
@ -347,7 +357,7 @@ class RMSNorm(nn.Module):
class RMSScaleShiftNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
self.scale_shift_process = nn.Linear(dim * 2, dim * 2)
@ -364,6 +374,7 @@ class RMSScaleShiftNorm(nn.Module):
# residual and residual gates
class Residual(nn.Module):
def __init__(self, dim, scale_residual=False):
super().__init__()
@ -386,24 +397,22 @@ class GRUGating(nn.Module):
if exists(self.residual_scale):
residual = residual * self.residual_scale
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d"))
return gated_output.reshape_as(x)
# token shifting
def shift(t, amount, mask=None):
if amount == 0:
return t
if exists(mask):
t = t.masked_fill(~mask[..., None], 0.)
t = t.masked_fill(~mask[..., None], 0.0)
return F.pad(t, (0, 0, amount, -amount), value=0.)
return F.pad(t, (0, 0, amount, -amount), value=0.0)
class ShiftTokens(nn.Module):
@ -413,7 +422,7 @@ class ShiftTokens(nn.Module):
self.shifts = tuple(shifts)
def forward(self, x, **kwargs):
mask = kwargs.get('mask', None)
mask = kwargs.get("mask", None)
shifts = self.shifts
segments = len(shifts)
feats_per_shift = x.shape[-1] // segments
@ -426,6 +435,7 @@ class ShiftTokens(nn.Module):
# feedforward
class GLU(nn.Module):
def __init__(self, dim_in, dim_out, activation):
super().__init__()
@ -446,24 +456,23 @@ class FeedForward(nn.Module):
glu=False,
relu_squared=False,
post_act_ln=False,
dropout=0.,
zero_init_output=False
dropout=0.0,
zero_init_output=False,
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
activation = ReluSquared() if relu_squared else nn.GELU()
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
activation
) if not glu else GLU(dim, inner_dim, activation)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), activation) if not glu else GLU(dim, inner_dim, activation)
)
self.net = nn.Sequential(
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
nn.Linear(inner_dim, dim_out),
)
# init last linear layer to 0
@ -476,6 +485,7 @@ class FeedForward(nn.Module):
# attention.
class Attention(nn.Module):
def __init__(
self,
@ -486,11 +496,11 @@ class Attention(nn.Module):
talking_heads=False,
head_scale=False,
collab_heads=False,
collab_compression=.3,
collab_compression=0.3,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
dropout=0.0,
on_attn=False,
gate_values=False,
zero_init_output=False,
@ -502,7 +512,7 @@ class Attention(nn.Module):
rel_pos_max_distance=128,
):
super().__init__()
self.scale = dim_head ** -0.5
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
@ -532,8 +542,9 @@ class Attention(nn.Module):
# cosine sim attention
self.qk_norm = qk_norm
if qk_norm:
scale_init_value = default(scale_init_value,
-3) # if not provided, initialize as though it were sequence length of 1024
scale_init_value = default(
scale_init_value, -3
) # if not provided, initialize as though it were sequence length of 1024
self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value)
# talking heads
@ -565,9 +576,16 @@ class Attention(nn.Module):
self.rel_pos_bias = rel_pos_bias
if rel_pos_bias:
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads,
num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance)
assert (
rel_pos_num_buckets <= rel_pos_max_distance
), "number of relative position buckets must be less than the relative position max distance"
self.rel_pos = RelativePositionBias(
scale=dim_head**0.5,
causal=causal,
heads=heads,
num_buckets=rel_pos_num_buckets,
max_distance=rel_pos_max_distance,
)
# init output projection 0
if zero_init_output:
@ -586,8 +604,16 @@ class Attention(nn.Module):
mem=None,
layer_past=None,
):
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists(
context)
b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = (
*x.shape,
self.heads,
self.talking_heads,
self.collab_heads,
self.head_scale,
self.scale,
x.device,
exists(context),
)
kv_input = default(context, x)
q_input = x
@ -609,11 +635,11 @@ class Attention(nn.Module):
v = self.to_v(v_input)
if not collab_heads:
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
else:
q = einsum('b i d, h d -> b h i d', q, self.collab_mixing)
k = rearrange(k, 'b n d -> b () n d')
v = rearrange(v, 'b n (h d) -> b h n d', h=h)
q = einsum("b i d, h d -> b h i d", q, self.collab_mixing)
k = rearrange(k, "b n d -> b () n d")
v = rearrange(v, "b n (h d) -> b h n d", h=h)
if layer_past is not None:
past_key, past_value = layer_past
@ -633,12 +659,12 @@ class Attention(nn.Module):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
q_mask = rearrange(q_mask, "b i -> b () i ()")
k_mask = rearrange(k_mask, "b j -> b () () j")
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
@ -651,7 +677,7 @@ class Attention(nn.Module):
q, k = map(l2norm, (q, k))
scale = 1 / (self.scale.exp().clamp(min=1e-2))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
@ -660,7 +686,7 @@ class Attention(nn.Module):
pre_softmax_attn = dots.clone()
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous()
if self.rel_pos_bias:
dots = self.rel_pos(dots)
@ -670,18 +696,20 @@ class Attention(nn.Module):
del input_mask
if exists(attn_mask):
assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
assert (
2 <= attn_mask.ndim <= 4
), "attention mask must have greater than 2 dimensions but less than or equal to 4"
if attn_mask.ndim == 2:
attn_mask = rearrange(attn_mask, 'i j -> () () i j')
attn_mask = rearrange(attn_mask, "i j -> () () i j")
elif attn_mask.ndim == 3:
attn_mask = rearrange(attn_mask, 'h i j -> () h i j')
attn_mask = rearrange(attn_mask, "h i j -> () h i j")
dots.masked_fill_(~attn_mask, mask_value)
if exists(self.max_attend_past):
i, j = dots.shape[-2:]
range_q = torch.arange(j - i, j, device=device)
range_k = torch.arange(j, device=device)
dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j')
dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j")
mask = dist > self.max_attend_past
dots.masked_fill_(mask, mask_value)
del mask
@ -689,7 +717,7 @@ class Attention(nn.Module):
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j")
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
@ -707,23 +735,20 @@ class Attention(nn.Module):
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
if head_scale:
out = out * self.head_scale_params
out = rearrange(out, 'b h n d -> b n (h d)')
out = rearrange(out, "b h n d -> b n (h d)")
if exists(self.to_v_gate):
gates = self.to_v_gate(x)
out = out * gates.sigmoid()
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn)
return self.to_out(out), intermediates, k_cache, v_cache
@ -761,20 +786,20 @@ class AttentionLayers(nn.Module):
use_qk_norm_attn=False,
qk_norm_attn_seq_len=None,
zero_init_branch_output=False,
**kwargs
**kwargs,
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs)
attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.causal = causal
rel_pos_bias = 'rel_pos_bias' in attn_kwargs
rel_pos_bias = "rel_pos_bias" in attn_kwargs
self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
@ -782,17 +807,18 @@ class AttentionLayers(nn.Module):
self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None
assert not (
alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
alibi_pos_bias and rel_pos_bias
), "you can only choose Alibi positional bias or T5 relative positional bias, not both"
if alibi_pos_bias:
alibi_num_heads = default(alibi_num_heads, heads)
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
assert alibi_num_heads <= heads, "number of ALiBi heads must be less than the total number of heads"
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal)
else:
self.rel_pos = None
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm"
self.pre_norm = pre_norm
self.sandwich_norm = sandwich_norm
@ -809,27 +835,30 @@ class AttentionLayers(nn.Module):
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ('a', 'c', 'f')
default_block = ("a", "c", "f")
elif cross_attend and only_cross:
default_block = ('c', 'f')
default_block = ("c", "f")
else:
default_block = ('a', 'f')
default_block = ("a", "f")
if macaron:
default_block = ('f',) + default_block
default_block = ("f",) + default_block
# qk normalization
if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(
qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
attn_scale_init_value = (
-math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len))
if exists(qk_norm_attn_seq_len)
else None
)
attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value}
# zero init
if zero_init_branch_output:
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
attn_kwargs = {**attn_kwargs, "zero_init_output": True}
ff_kwargs = {**ff_kwargs, "zero_init_output": True}
# calculate layer block order
@ -837,23 +866,23 @@ class AttentionLayers(nn.Module):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
assert 1 < par_ratio <= par_depth, "par ratio out of range"
default_block = tuple(filter(not_equals("f"), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
assert len(default_block) <= par_width, "default block is too large for par_ratio"
par_block = default_block + ("f",) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
layer_types = par_head + ("f",) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth"
layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
self.num_attn_layers = len(list(filter(equals("a"), layer_types)))
# calculate token shifting
@ -864,15 +893,15 @@ class AttentionLayers(nn.Module):
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
is_last_layer = ind == (len(self.layer_types) - 1)
if layer_type == 'a':
if layer_type == "a":
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
elif layer_type == "c":
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
elif layer_type == "f":
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f'invalid layer type {layer_type}')
raise Exception(f"invalid layer type {layer_type}")
if layer_shift_tokens > 0:
shift_range_upper = layer_shift_tokens + 1
@ -885,23 +914,15 @@ class AttentionLayers(nn.Module):
residual_fn = GRUGating if gate_residual else Residual
residual = residual_fn(dim, scale_residual=scale_residual)
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c')
layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c")
pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None
post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None
post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
norms = nn.ModuleList([
pre_branch_norm,
post_branch_norm,
post_main_norm
])
norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm])
self.layers.append(nn.ModuleList([
norms,
layer,
residual
]))
self.layers.append(nn.ModuleList([norms, layer, residual]))
def forward(
self,
@ -918,9 +939,10 @@ class AttentionLayers(nn.Module):
expected_seq_len=None,
):
assert not (self.cross_attend ^ (exists(context) or exists(
full_context))), 'context must be passed in if cross_attend is set to True'
assert context is None or full_context is None, 'only one of full_context or context can be provided'
assert not (
self.cross_attend ^ (exists(context) or exists(full_context))
), "context must be passed in if cross_attend is set to True"
assert context is None or full_context is None, "only one of full_context or context can be provided"
hiddens = []
intermediates = []
@ -930,24 +952,28 @@ class AttentionLayers(nn.Module):
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
norm_args = {}
if exists(norm_scale_shift_inp):
norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp
norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp
rotary_pos_emb = None
if exists(self.rotary_pos_emb):
if not self.training and self.causal:
assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
assert (
expected_seq_len is not None
), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`"
elif expected_seq_len is None:
expected_seq_len = 0
seq_len = x.shape[1]
if past_key_values is not None:
seq_len += past_key_values[0][0].shape[-2]
max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len])
max_rotary_emb_length = max(
list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]
)
rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
present_key_values = []
cross_attn_count = 0
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
if layer_type == 'a':
if layer_type == "a":
layer_mem = mems.pop(0) if mems else None
residual = x
@ -957,26 +983,39 @@ class AttentionLayers(nn.Module):
if exists(pre_branch_norm):
x = pre_branch_norm(x, **norm_args)
if layer_type == 'a' or layer_type == 'c':
if layer_type == "a" or layer_type == "c":
if past_key_values is not None:
layer_kv = past_key_values.pop(0)
layer_past = tuple(s.to(x.device) for s in layer_kv)
else:
layer_past = None
if layer_type == 'a':
out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb,
prev_attn, layer_mem, layer_past)
elif layer_type == 'c':
if layer_type == "a":
out, inter, k, v = block(
x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past
)
elif layer_type == "c":
if exists(full_context):
out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None,
None, prev_attn, None, layer_past)
out, inter, k, v = block(
x,
full_context[cross_attn_count],
mask,
context_mask,
None,
None,
None,
prev_attn,
None,
layer_past,
)
else:
out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past)
elif layer_type == 'f':
out, inter, k, v = block(
x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past
)
elif layer_type == "f":
out = block(x)
if layer_type == 'a' or layer_type == 'c' and present_key_values is not None:
if layer_type == "a" or layer_type == "c" and present_key_values is not None:
present_key_values.append((k.detach(), v.detach()))
if exists(post_branch_norm):
@ -984,28 +1023,26 @@ class AttentionLayers(nn.Module):
x = residual_fn(out, residual)
if layer_type in ('a', 'c'):
if layer_type in ("a", "c"):
intermediates.append(inter)
if layer_type == 'a' and self.residual_attn:
if layer_type == "a" and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == 'c' and self.cross_residual_attn:
elif layer_type == "c" and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if exists(post_main_norm):
x = post_main_norm(x, **norm_args)
if layer_type == 'c':
if layer_type == "c":
cross_attn_count += 1
if layer_type == 'f':
if layer_type == "f":
hiddens.append(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates,
past_key_values=present_key_values
hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values
)
return x, intermediates
@ -1015,13 +1052,13 @@ class AttentionLayers(nn.Module):
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
assert "causal" not in kwargs, "cannot set causality on encoder"
super().__init__(causal=False, **kwargs)
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder'
assert "causal" not in kwargs, "cannot set causality on decoder"
super().__init__(causal=True, **kwargs)
@ -1031,22 +1068,13 @@ class CrossAttender(AttentionLayers):
class ViTransformerWrapper(nn.Module):
def __init__(
self,
*,
image_size,
patch_size,
attn_layers,
num_classes=None,
dropout=0.,
emb_dropout=0.
):
def __init__(self, *, image_size, patch_size, attn_layers, num_classes=None, dropout=0.0, emb_dropout=0.0):
super().__init__()
assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder"
assert image_size % patch_size == 0, "image dimensions must be divisible by the patch size"
dim = attn_layers.dim
num_patches = (image_size // patch_size) ** 2
patch_dim = 3 * patch_size ** 2
patch_dim = 3 * patch_size**2
self.patch_size = patch_size
@ -1059,20 +1087,16 @@ class ViTransformerWrapper(nn.Module):
self.norm = nn.LayerNorm(dim)
self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None
def forward(
self,
img,
return_embeddings=False
):
def forward(self, img, return_embeddings=False):
p = self.patch_size
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embedding[:, :(n + 1)]
x = x + self.pos_embedding[:, : (n + 1)]
x = self.dropout(x)
x = self.attn_layers(x)
@ -1092,15 +1116,15 @@ class TransformerWrapper(nn.Module):
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
max_mem_len=0.0,
shift_mem_down=0,
emb_dropout=0.,
emb_dropout=0.0,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
use_pos_emb=True,
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
@ -1110,8 +1134,11 @@ class TransformerWrapper(nn.Module):
self.shift_mem_down = shift_mem_down
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.pos_emb = (
AbsolutePositionalEmbedding(emb_dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@ -1140,7 +1167,7 @@ class TransformerWrapper(nn.Module):
return_attn=False,
mems=None,
use_cache=False,
**kwargs
**kwargs,
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
@ -1150,7 +1177,7 @@ class TransformerWrapper(nn.Module):
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
mem = repeat(self.memory_tokens, "n d -> b n d", b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
@ -1158,7 +1185,7 @@ class TransformerWrapper(nn.Module):
mask = F.pad(mask, (num_mem, 0), value=True)
if self.shift_mem_down and exists(mems):
mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :]
mems = [*mems_r, *mems_l]
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
@ -1186,25 +1213,20 @@ class TransformerWrapper(nn.Module):
class ContinuousTransformerWrapper(nn.Module):
def __init__(
self,
*,
max_seq_len,
attn_layers,
dim_in=None,
dim_out=None,
emb_dim=None,
emb_dropout=0.,
use_pos_emb=True
self, *, max_seq_len, attn_layers, dim_in=None, dim_out=None, emb_dim=None, emb_dropout=0.0, use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder"
dim = attn_layers.dim
self.max_seq_len = max_seq_len
self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.pos_emb = (
AbsolutePositionalEmbedding(dim, max_seq_len)
if (use_pos_emb and not attn_layers.has_pos_emb)
else always(0)
)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
@ -1214,16 +1236,7 @@ class ContinuousTransformerWrapper(nn.Module):
self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_attn=False,
mems=None,
use_cache=False,
**kwargs
):
def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs):
b, n, _, device = *x.shape, x.device
x = self.project_in(x)
@ -1245,4 +1258,3 @@ class ContinuousTransformerWrapper(nn.Module):
if len(res) > 1:
return tuple(res)
return res[0]

View File

@ -1,31 +0,0 @@
from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock
class FramePriorNet(nn.Module):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, num_res_blocks=13, num_conv_blocks=2
):
super().__init__()
self.res_blocks = nn.ModuleList()
for idx in range(num_res_blocks):
block = Conv1dBNBlock(
in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == num_res_blocks else hidden_channels,
hidden_channels,
kernel_size,
1,
num_conv_blocks,
)
self.res_blocks.append(block)
def forward(self, x, x_mask=None):
if x_mask is None:
x_mask = 1.0
o = x * x_mask
for block in self.res_blocks:
res = o
o = block(o)
o = o + res
if x_mask is not None:
o = o * x_mask
return o

View File

@ -1,91 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, filter):
super().__init__()
self.num_mel = num_mel
start_index = 2
end_index = filter / 16
i = start_index
filt_len = []
while i <= end_index:
i = i * 2
filt_len.append(i)
filt_len.append(i)
filters = [1] + filt_len
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
)
def forward(self, inputs, input_lengths):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
valid_lengths = input_lengths.float() # [batch_size]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
# Create the post conv width mask based on the valid lengths of the output of the convolution.
# The valid length of the output of a convolution on varying-length inputs is
# ceil(input_length/stride) + 1 for kernel_size=3, stride=2, padding=2
# For example (kernel_size=3, stride=2, padding=2):
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
# _____
# x _____
# x _____
# x ____
# x
# x x x x -> Output valid length = 4
# Since every example in the batch is zero padded and therefore has separate valid_lengths,
# we need to mask off all the values AFTER the valid length for each example in the batch.
# Otherwise, the convolutions create noise and a lot of spurious information
valid_lengths = (valid_lengths / 2).float()
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
post_conv_max_width = x.size(2)
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
len(valid_lengths), post_conv_max_width
) < valid_lengths.unsqueeze(1)
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
x = x * mask
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
post_conv_input_lengths = valid_lengths
packed_seqs = nn.utils.rnn.pack_padded_sequence(
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.recurrence.flatten_parameters()
_, (ht, _) = self.recurrence(packed_seqs)
last_output = ht[-1]
return last_output.to(inputs.device) # [B, 128]

View File

@ -1,67 +0,0 @@
import torch
from torch.autograd import Function
class VectorQuantization(Function):
@staticmethod
def forward(ctx, inputs, codebook):
with torch.no_grad():
embedding_size = codebook.size(1)
inputs_size = inputs.size()
inputs_flatten = inputs.view(-1, embedding_size)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
# Compute the distances to the codebook
distances = torch.addmm(codebook_sqr + inputs_sqr,
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
_, indices_flatten = torch.min(distances, dim=1)
indices = indices_flatten.view(*inputs_size[:-1])
ctx.mark_non_differentiable(indices)
return indices
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError('Trying to call `.grad()` on graph containing '
'`VectorQuantization`. The function `VectorQuantization` '
'is not differentiable. Use `VectorQuantizationStraightThrough` '
'if you want a straight-through estimator of the gradient.')
class VectorQuantizationStraightThrough(Function):
@staticmethod
def forward(ctx, inputs, codebook):
indices = vq(inputs, codebook)
indices_flatten = indices.view(-1)
ctx.save_for_backward(indices_flatten, codebook)
ctx.mark_non_differentiable(indices_flatten)
codes_flatten = torch.index_select(codebook, dim=0,
index=indices_flatten)
codes = codes_flatten.view_as(inputs)
return (codes, indices_flatten)
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
# Straight-through estimator
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
embedding_size = codebook.size(1)
grad_output_flatten = (grad_output.contiguous()
.view(-1, embedding_size))
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output_flatten)
return (grad_inputs, grad_codebook)
vq = VectorQuantization.apply
vq_st = VectorQuantizationStraightThrough.apply
__all__ = ["vq", "vq_st"]
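
A short usage sketch of the straight-through quantizer defined above; shapes and the downstream loss are illustrative only, and the snippet assumes the module above is importable so that vq_st is in scope.

import torch

# `vq_st` is the VectorQuantizationStraightThrough.apply alias defined above.
codebook = torch.randn(512, 64, requires_grad=True)    # 512 codes of dimension 64
latents = torch.randn(8, 100, 64, requires_grad=True)  # encoder output [B, T, D]

codes, indices = vq_st(latents, codebook)  # codes: [8, 100, 64], indices: [800]
loss = codes.pow(2).mean()                 # any downstream loss on the quantized codes
loss.backward()                            # straight-through: gradient reaches `latents` unchanged
print(latents.grad.shape)   # torch.Size([8, 100, 64])
print(codebook.grad.shape)  # torch.Size([512, 64])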

View File

@ -1,305 +0,0 @@
import copy
from abc import abstractmethod
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields as class attributes
for key in config:
setattr(self, key, config[key])
# layers
self.embedding = None
self.encoder = None
self.decoder = None
self.postnet = None
# init tensors
self.embedded_speakers = None
self.embedded_speakers_projected = None
# global style token
if self.gst and self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# Capacitron
if self.capacitron_vae and self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
self.capacitron_vae_layer = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
"""Set missing fields to their default values"""
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
return None
#############################
# INIT FUNCTIONS
#############################
def _init_backward_decoder(self):
"""Init the backward decoder for Forward-Backward decoding."""
self.decoder_backward = copy.deepcopy(self.decoder)
def _init_coarse_decoder(self):
"""Init the coarse decoder for Double-Decoder Consistency."""
self.coarse_decoder = copy.deepcopy(self.decoder)
self.coarse_decoder.r_init = self.ddc_r
self.coarse_decoder.set_r(self.ddc_r)
#############################
# CORE FUNCTIONS
#############################
@abstractmethod
def forward(self):
pass
@abstractmethod
def inference(self):
pass
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals.
Args:
config (Coqpit): model configuration.
checkpoint_path (str): path to checkpoint file.
eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
if "r" in state:
# set r from the state (for compatibility with older checkpoints)
self.decoder.set_r(state["r"])
elif "config" in state:
# set r from config used at training time (for inference)
self.decoder.set_r(state["config"]["r"])
else:
# set r from the new config (for new-models)
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
assert not self.training
def get_criterion(self) -> nn.Module:
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
##########################
# TEST AND LOG FUNCTIONS #
##########################
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return {"figures": test_figures, "audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
#############################
# COMMON COMPUTE FUNCTIONS
#############################
def compute_masks(self, text_lengths, mel_lengths):
"""Compute masks against sequence paddings."""
# B x T_in_max (boolean)
input_mask = sequence_mask(text_lengths)
output_mask = None
if mel_lengths is not None:
max_len = mel_lengths.max()
r = self.decoder.r
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
output_mask = sequence_mask(mel_lengths, max_len=max_len)
return input_mask, output_mask
def _backward_pass(self, mel_specs, encoder_outputs, mask):
"""Run backwards decoder"""
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
return decoder_outputs_b, alignments_b
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
"""Double Decoder Consistency"""
T = mel_specs.shape[1]
if T % self.coarse_decoder.r > 0:
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask
)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode="nearest",
).transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
return decoder_outputs_backward, alignments_backward
#############################
# EMBEDDING FUNCTIONS
#############################
def compute_gst(self, inputs, style_input, speaker_embedding=None):
"""Compute global style token"""
if isinstance(style_input, dict):
# multiply each style token with a weight
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
if speaker_embedding is not None:
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
for k_token, v_amplifier in style_input.items():
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
# ignore style token and return zero tensor
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
else:
# compute style tokens
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(
VAE_outputs,
posterior_distribution,
prior_distribution,
capacitron_beta,
) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable
)
VAE_outputs = VAE_outputs.to(inputs.device)
encoder_output = self._concat_speaker_embedding(
inputs, VAE_outputs
) # concatenate to the output of the basic tacotron encoder
return (
encoder_output,
posterior_distribution,
prior_distribution,
capacitron_beta,
)
@staticmethod
def _add_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = outputs + embedded_speakers_
return outputs
@staticmethod
def _concat_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
return outputs
#############################
# CALLBACKS
#############################
def on_epoch_start(self, trainer):
"""Callback for setting values wrt gradual training schedule.
Args:
trainer (TrainerTTS): TTS trainer object that is used to train this model.
"""
if self.gradual_training:
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
trainer.config.r = r
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")

View File

@ -2,12 +2,12 @@
import os
import random
from contextlib import contextmanager
from time import time
import torch
import torch.nn.functional as F
import torchaudio
from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
@ -16,22 +16,14 @@ from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP
from TTS.tts.layers.tortoise.cvvp import CVVP
from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.diffusion import (
SpacedDiffusion,
get_named_beta_schedule,
space_timesteps,
)
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from contextlib import contextmanager
def pad_or_truncate(t, length):
"""
@ -56,9 +48,7 @@ def load_discrete_vocoder_diffuser(
Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
"""
return SpacedDiffusion(
use_timesteps=space_timesteps(
trained_diffusion_steps, [desired_diffusion_steps]
),
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type="epsilon",
model_var_type="learned_range",
loss_type="mse",
@ -141,7 +131,7 @@ def do_spectrogram_diffusion(
output_shape,
noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=verbose
progress=verbose,
)
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
@ -166,9 +156,7 @@ def classify_audio_clip(clip):
kernel_size=5,
distribute_zero_label=False,
)
classifier.load_state_dict(
torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
)
classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
@ -238,9 +226,7 @@ class TextToSpeech:
self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed
self.models_dir = models_dir
self.autoregressive_batch_size = (
pick_best_batch_size_for_gpu()
if autoregressive_batch_size is None
else autoregressive_batch_size
pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
)
self.enable_redaction = enable_redaction
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -274,9 +260,7 @@ class TextToSpeech:
self.autoregressive.load_state_dict(torch.load(ar_path))
self.autoregressive.post_init_gpt2_config(kv_cache)
diff_path = diff_checkpoint or get_model_path(
"diffusion_decoder.pth", models_dir
)
diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
self.diffusion = (
DiffusionTts(
model_channels=1024,
@ -365,9 +349,7 @@ class TextToSpeech:
.cpu()
.eval()
)
self.cvvp.load_state_dict(
torch.load(get_model_path("cvvp.pth", self.models_dir))
)
self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
def get_conditioning_latents(
self,
@ -407,11 +389,7 @@ class TextToSpeech:
DURS_CONST = 102400
for ls in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = (
torchaudio.functional.resample(ls[0], 22050, 24000)
if original_tortoise
else ls[1]
)
sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
if latent_averaging_mode == 0:
sample = pad_or_truncate(sample, DURS_CONST)
cond_mel = wav_to_univnet_mel(
@ -426,9 +404,7 @@ class TextToSpeech:
if latent_averaging_mode == 2:
temp_diffusion_conds = []
for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
current_sample = sample[
:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST
]
current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST]
current_sample = pad_or_truncate(current_sample, DURS_CONST)
cond_mel = wav_to_univnet_mel(
current_sample.to(self.device),
@ -440,9 +416,7 @@ class TextToSpeech:
elif latent_averaging_mode == 2:
temp_diffusion_conds.append(cond_mel)
if latent_averaging_mode == 2:
diffusion_conds.append(
torch.stack(temp_diffusion_conds).mean(0)
)
diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0))
diffusion_conds = torch.stack(diffusion_conds, dim=1)
with self.temporary_cuda(self.diffusion) as diffusion:
@ -471,9 +445,7 @@ class TextToSpeech:
)
)
with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(
torch.tensor([0.0])
)
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
def tts_with_preset(self, text, preset="fast", **kwargs):
"""
@ -521,10 +493,7 @@ class TextToSpeech:
"diffusion_iterations": 50,
"sampler": "ddim",
},
"fast_old": {
"num_autoregressive_samples": 96,
"diffusion_iterations": 80
},
"fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
"standard": {
"num_autoregressive_samples": 256,
"diffusion_iterations": 200,
@ -618,9 +587,7 @@ class TextToSpeech:
"""
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
text_tokens = (
torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
)
text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert (
text_tokens.shape[-1] < 400
@ -628,12 +595,7 @@ class TextToSpeech:
auto_conds = None
if voice_samples is not None:
(
auto_conditioning,
diffusion_conditioning,
auto_conds,
_,
) = self.get_conditioning_latents(
(auto_conditioning, diffusion_conditioning, auto_conds, _,) = self.get_conditioning_latents(
voice_samples,
return_mels=True,
latent_averaging_mode=latent_averaging_mode,
@ -650,10 +612,7 @@ class TextToSpeech:
diffusion_conditioning = diffusion_conditioning.to(self.device)
diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=diffusion_iterations,
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=sampler
desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler
)
# in the case of single_sample,
@ -664,13 +623,13 @@ class TextToSpeech:
samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
calm_token = (
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device)
if verbose:
print("Generating autoregressive samples..")
with self.temporary_cuda(
self.autoregressive
) as autoregressive, torch.autocast(
with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half
):
for b in tqdm(range(num_batches), disable=not verbose):
@ -689,9 +648,7 @@ class TextToSpeech:
padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes)
self.autoregressive_batch_size = (
orig_batch_size # in the case of single_sample
)
self.autoregressive_batch_size = orig_batch_size # in the case of single_sample
clip_results = []
with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
@ -729,9 +686,7 @@ class TextToSpeech:
if cvvp_amount == 1:
clip_results.append(cvvp)
else:
clip_results.append(
cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)
)
clip_results.append(cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount))
else:
clip_results.append(clvp_res)
clip_results = torch.cat(clip_results, dim=0)
@ -744,19 +699,14 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage.
with self.temporary_cuda(
self.autoregressive
) as autoregressive:
with self.temporary_cuda(self.autoregressive) as autoregressive:
best_latents = autoregressive(
auto_conditioning.repeat(k, 1),
text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
best_results,
torch.tensor(
[
best_results.shape[-1]
* self.autoregressive.mel_length_compression
],
[best_results.shape[-1] * self.autoregressive.mel_length_compression],
device=text_tokens.device,
),
return_latent=True,
@ -778,9 +728,7 @@ class TextToSpeech:
ctokens += 1
else:
ctokens = 0
if (
ctokens > 8
): # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
with self.temporary_cuda(self.diffusion) as diffusion:
@ -801,10 +749,7 @@ class TextToSpeech:
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
return clip
wav_candidates = [
potentially_redact(wav_candidate, text)
for wav_candidate in wav_candidates
]
wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
if len(wav_candidates) > 1:
res = wav_candidates
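
For latent_averaging_mode == 2 in get_conditioning_latents above, each reference clip is split into fixed-size chunks, a conditioning mel is computed per chunk, and the chunks are averaged. A rough standalone sketch with placeholders standing in for wav_to_univnet_mel and pad_or_truncate:

import math
import torch
import torch.nn.functional as F

DURS_CONST = 102400                    # chunk length used above
sample = torch.randn(1, 250000)        # one hypothetical 24 kHz reference clip

def fake_cond_mel(chunk):              # placeholder for wav_to_univnet_mel(chunk, ...)
    return chunk.mean(dim=1, keepdim=True)

temp_diffusion_conds = []
for c in range(math.ceil(sample.shape[1] / DURS_CONST)):
    chunk = sample[:, c * DURS_CONST : (c + 1) * DURS_CONST]
    chunk = F.pad(chunk, (0, DURS_CONST - chunk.shape[1]))  # crude stand-in for pad_or_truncate
    temp_diffusion_conds.append(fake_cond_mel(chunk))
diffusion_cond = torch.stack(temp_diffusion_conds).mean(0)  # average over chunks, as in mode 2
print(diffusion_cond.shape)  # torch.Size([1, 1])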

View File

@ -97,7 +97,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
self.normalized=normalized
self.normalized = normalized
if use_mel:
self._build_mel_basis()
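
The normalized flag stored in the constructor above presumably feeds the normalized argument of torch.stft. A minimal sketch of a normalized STFT with an explicit window, using hypothetical parameters and without the mel projection this class adds:

import torch

n_fft, hop_length, win_length = 1024, 256, 1024   # hypothetical values
window = torch.hann_window(win_length)             # same idea as the nn.Parameter window above
y = torch.randn(1, 22050)                          # one second of fake audio at 22.05 kHz
spec = torch.stft(
    y, n_fft, hop_length=hop_length, win_length=win_length,
    window=window, normalized=True, return_complex=True,
)
magnitude = spec.abs()  # [1, n_fft // 2 + 1, num_frames]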