style fix

manmay-nakhashi 2023-04-22 17:49:47 +05:30
parent b892aa925a
commit 18c745ceef
22 changed files with 723 additions and 1605 deletions

View File

@@ -1,9 +1,8 @@
-import os
 import functools
 import math
+import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
@@ -11,6 +10,7 @@ from transformers import LogitsWarper
 from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias
+
 def zero_module(module):
     """
     Zero out the parameters of a module and return it.
@@ -64,11 +64,11 @@ class QKVAttentionLegacy(nn.Module):
         ch = width // (3 * self.n_heads)
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = torch.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
         if rel_pos is not None:
-            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
+                bs * self.n_heads, weight.shape[-2], weight.shape[-1]
+            )
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         if mask is not None:
             # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
@@ -112,7 +112,13 @@ class AttentionBlock(nn.Module):
         self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
         if relative_pos_embeddings:
-            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+            self.relative_pos_embeddings = RelativePositionBias(
+                scale=(channels // self.num_heads) ** 0.5,
+                causal=False,
+                heads=num_heads,
+                num_buckets=32,
+                max_distance=64,
+            )
         else:
             self.relative_pos_embeddings = None
@@ -168,9 +174,7 @@ class Downsample(nn.Module):
         stride = factor
         if use_conv:
-            self.op = nn.Conv1d(
-                self.channels, self.out_channels, ksize, stride=stride, padding=pad
-            )
+            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
         else:
             assert self.channels == self.out_channels
             self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
@@ -182,15 +186,15 @@ class Downsample(nn.Module):
 class ResBlock(nn.Module):
     def __init__(
         self,
         channels,
         dropout,
         out_channels=None,
         use_conv=False,
         use_scale_shift_norm=False,
         up=False,
         down=False,
         kernel_size=3,
     ):
         super().__init__()
         self.channels = channels
@@ -221,17 +225,13 @@ class ResBlock(nn.Module):
             normalization(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
-            ),
+            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
         )
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = nn.Conv1d(
-                channels, self.out_channels, kernel_size, padding=padding
-            )
+            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
         else:
             self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
@@ -249,37 +249,38 @@ class ResBlock(nn.Module):
 class AudioMiniEncoder(nn.Module):
-    def __init__(self,
-                 spec_dim,
-                 embedding_dim,
-                 base_channels=128,
-                 depth=2,
-                 resnet_blocks=2,
-                 attn_blocks=4,
-                 num_attn_heads=4,
-                 dropout=0,
-                 downsample_factor=2,
-                 kernel_size=3):
+    def __init__(
+        self,
+        spec_dim,
+        embedding_dim,
+        base_channels=128,
+        depth=2,
+        resnet_blocks=2,
+        attn_blocks=4,
+        num_attn_heads=4,
+        dropout=0,
+        downsample_factor=2,
+        kernel_size=3,
+    ):
         super().__init__()
-        self.init = nn.Sequential(
-            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
-        )
+        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
         ch = base_channels
         res = []
         for l in range(depth):
             for r in range(resnet_blocks):
                 res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
-            res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
+            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
-        self.final = nn.Sequential(
-            normalization(ch),
-            nn.SiLU(),
-            nn.Conv1d(ch, embedding_dim, 1)
-        )
+        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
         attn = []
         for a in range(attn_blocks):
-            attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
+            attn.append(
+                AttentionBlock(
+                    embedding_dim,
+                    num_attn_heads,
+                )
+            )
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim
@@ -291,15 +292,24 @@ class AudioMiniEncoder(nn.Module):
         return h[:, :, 0]
-DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../utils/assets/tortoise/mel_norms.pth')
+DEFAULT_MEL_NORM_FILE = os.path.join(
+    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/mel_norms.pth"
+)
 class TorchMelSpectrogram(nn.Module):
-    def __init__(self, filter_length=1024, hop_length=256,
-                 win_length=1024, n_mel_channels=80,
-                 mel_fmin=0, mel_fmax=8000,
-                 sampling_rate=22050, normalize=False,
-                 mel_norm_file=DEFAULT_MEL_NORM_FILE):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=80,
+        mel_fmin=0,
+        mel_fmax=8000,
+        sampling_rate=22050,
+        normalize=False,
+        mel_norm_file=DEFAULT_MEL_NORM_FILE,
+    ):
         super().__init__()
         # These are the default tacotron values for the MEL spectrogram.
         self.filter_length = filter_length
@@ -309,16 +319,18 @@ class TorchMelSpectrogram(nn.Module):
         self.mel_fmin = mel_fmin
         self.mel_fmax = mel_fmax
         self.sampling_rate = sampling_rate
-        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length,
-                                                             hop_length=self.hop_length,
-                                                             win_length=self.win_length,
-                                                             power=2,
-                                                             normalized=normalize,
-                                                             sample_rate=self.sampling_rate,
-                                                             f_min=self.mel_fmin,
-                                                             f_max=self.mel_fmax,
-                                                             n_mels=self.n_mel_channels,
-                                                             norm="slaney")
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            n_fft=self.filter_length,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            power=2,
+            normalized=normalize,
+            sample_rate=self.sampling_rate,
+            f_min=self.mel_fmin,
+            f_max=self.mel_fmax,
+            n_mels=self.n_mel_channels,
+            norm="slaney",
+        )
         self.mel_norm_file = mel_norm_file
         if self.mel_norm_file is not None:
             self.mel_norms = torch.load(self.mel_norm_file)
@@ -326,7 +338,9 @@ class TorchMelSpectrogram(nn.Module):
             self.mel_norms = None
     def forward(self, inp):
-        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
+        if (
+            len(inp.shape) == 3
+        ):  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
             inp = inp.squeeze(1)
         assert len(inp.shape) == 2
         self.mel_stft = self.mel_stft.to(inp.device)
@@ -344,6 +358,7 @@ class CheckpointedLayer(nn.Module):
     Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
     checkpoint for all other args.
     """
+
     def __init__(self, wrap):
         super().__init__()
         self.wrap = wrap
@@ -360,6 +375,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
     Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
     to channels-last that XTransformer expects.
     """
+
     def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
         super().__init__()
         self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
@@ -374,10 +390,10 @@ class CheckpointedXTransformerEncoder(nn.Module):
     def forward(self, x, **kwargs):
         if self.needs_permute:
-            x = x.permute(0,2,1)
+            x = x.permute(0, 2, 1)
         h = self.transformer(x, **kwargs)
         if self.exit_permute:
-            h = h.permute(0,2,1)
+            h = h.permute(0, 2, 1)
         return h
@@ -392,9 +408,7 @@ class TypicalLogitsWarper(LogitsWarper):
         self.mass = mass
         self.min_tokens_to_keep = min_tokens_to_keep
-    def __call__(
-        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
-    ) -> torch.FloatTensor:
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # calculate entropy
         normalized = torch.nn.functional.log_softmax(scores, dim=-1)
         p = torch.exp(normalized)
@@ -409,15 +423,11 @@ class TypicalLogitsWarper(LogitsWarper):
         # Remove tokens with cumulative mass above the threshold
         last_ind = (cumulative_probs < self.mass).sum(dim=1)
         last_ind[last_ind < 0] = 0
-        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
-            1, last_ind.view(-1, 1)
-        )
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
         if self.min_tokens_to_keep > 1:
             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
             sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         scores = scores.masked_fill(indices_to_remove, self.filter_value)
         return scores

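The QKVAttentionLegacy hunk above keeps the original comment about f16 stability: scaling q and k each by 1 / ch ** 0.25 before the einsum gives the same scores as dividing q·k by sqrt(ch) afterwards, while keeping the intermediate values smaller in half precision. A minimal standalone sketch of that equivalence (the shapes below are made up for illustration and are not part of the diff):

import math
import torch

bs_heads, ch, length = 4, 64, 10
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
# Scale both operands before the product (the form used in QKVAttentionLegacy)...
pre_scaled = torch.einsum("bct,bcs->bts", q * scale, k * scale)
# ...which matches dividing the raw scores by sqrt(ch) afterwards.
post_scaled = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
print(torch.allclose(pre_scaled, post_scaled, atol=1e-5))  # True, up to float error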
View File

@@ -7,11 +7,10 @@ import numpy as np
 import torch
 import torchaudio
 from scipy.io.wavfile import read
+
 from TTS.utils.audio.torch_transforms import TorchSTFT
-BUILTIN_VOICES_DIR = os.path.join(
-    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices"
-)
+BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices")
 def load_wav_to_torch(full_path):
@@ -58,10 +57,7 @@ def read_audio_file(audiopath: str):
 def load_required_audio(audiopath: str):
     audio, lsr = read_audio_file(audiopath)
-    audios = [
-        torchaudio.functional.resample(audio, lsr, sampling_rate)
-        for sampling_rate in (22050, 24000)
-    ]
+    audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)]
     for audio in audios:
         check_audio(audio, audiopath)
@@ -83,9 +79,7 @@ TACOTRON_MEL_MIN = -11.512925148010254
 def denormalize_tacotron_mel(norm_mel):
-    return ((norm_mel + 1) / 2) * (
-        TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
-    ) + TACOTRON_MEL_MIN
+    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
 def normalize_tacotron_mel(mel):
@@ -118,11 +112,7 @@ def get_voices(extra_voice_dirs: List[str] = []):
         for sub in subs:
             subj = os.path.join(d, sub)
             if os.path.isdir(subj):
-                voices[sub] = (
-                    list(glob(f"{subj}/*.wav"))
-                    + list(glob(f"{subj}/*.mp3"))
-                    + list(glob(f"{subj}/*.pth"))
-                )
+                voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth"))
     return voices
@@ -148,9 +138,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
     for voice in voices:
         if voice == "random":
             if len(voices) > 1:
-                print(
-                    "Cannot combine a random voice with a non-random voice. Just using a random voice."
-                )
+                print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
                 return None, None
             clip, latent = load_voice(voice, extra_voice_dirs)
             if latent is None:
@@ -171,18 +159,21 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
         latents = (latents_0, latents_1)
         return None, latents
 def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
-    stft = TorchSTFT(n_fft=1024,
-                     hop_length=256,
-                     win_length=1024,
-                     use_mel=True,
-                     n_mels=100,
-                     sample_rate=24000,
-                     mel_fmin=0,
-                     mel_fmax=12000)
+    stft = TorchSTFT(
+        n_fft=1024,
+        hop_length=256,
+        win_length=1024,
+        use_mel=True,
+        n_mels=100,
+        sample_rate=24000,
+        mel_fmin=0,
+        mel_fmax=12000,
+    )
     stft = stft.to(device)
     mel = stft(wav)
     mel = dynamic_range_compression(mel)
     if do_normalization:
         mel = normalize_tacotron_mel(mel)
     return mel

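denormalize_tacotron_mel in the hunk above is an affine map from [-1, 1] back to the raw log-mel range; normalize_tacotron_mel, whose body falls outside these hunks, is presumably the inverse map. A small standalone sketch of that round trip under that assumption (MEL_MAX below is a placeholder; only the MIN constant is visible in this diff):

import torch

MEL_MIN, MEL_MAX = -11.512925148010254, 2.0  # MEL_MAX is an illustrative placeholder

def denormalize(norm_mel):
    # Same affine map as denormalize_tacotron_mel: [-1, 1] -> [MEL_MIN, MEL_MAX]
    return ((norm_mel + 1) / 2) * (MEL_MAX - MEL_MIN) + MEL_MIN

def normalize(mel):
    # Assumed inverse: [MEL_MIN, MEL_MAX] -> [-1, 1]
    return 2 * (mel - MEL_MIN) / (MEL_MAX - MEL_MIN) - 1

x = torch.linspace(MEL_MIN, MEL_MAX, 5)
print(torch.allclose(denormalize(normalize(x)), x))  # True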
View File

@@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
 from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
+
 def null_position_embeddings(range, dim):
     return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
@@ -98,9 +99,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         assert self.cached_mel_emb is not None
         assert inputs_embeds is None  # Not supported by this inference model.
         assert labels is None  # Training not supported by this inference model.
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         # Create embedding
         mel_len = self.cached_mel_emb.shape[1]
@@ -109,9 +108,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
             text_emb = self.embeddings(text_inputs)
             text_emb = text_emb + self.text_pos_embedding(text_emb)
             if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
-                mel_emb = self.cached_mel_emb.repeat_interleave(
-                    text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
-                )
+                mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
             else:  # this outcome only occurs once per loop in most cases
                 mel_emb = self.cached_mel_emb
             emb = torch.cat([mel_emb, text_emb], dim=1)
@@ -158,10 +155,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
         """
         return tuple(
-            tuple(
-                past_state.index_select(0, beam_idx.to(past_state.device))
-                for past_state in layer_past
-            )
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
@@ -210,9 +204,7 @@ class LearnedPositionEmbeddings(nn.Module):
         return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
-def build_hf_gpt_transformer(
-    layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
-):
+def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
     """
     GPT-2 implemented by the HuggingFace library.
     """
@@ -230,9 +222,7 @@ def build_hf_gpt_transformer(
     )
     gpt = GPT2Model(gpt_config)
     # Override the built in positional embeddings
-    del (
-        gpt.wpe
-    )  # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
+    del gpt.wpe  # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
     gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
     # Built-in token embeddings are unused.
     del gpt.wte
@@ -251,21 +241,15 @@ class MelEncoder(nn.Module):
         self.channels = channels
         self.encoder = nn.Sequential(
             nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
-            nn.Sequential(
-                *[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
             nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
             nn.GroupNorm(channels // 16, channels // 2),
             nn.ReLU(),
-            nn.Sequential(
-                *[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
             nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
             nn.GroupNorm(channels // 8, channels),
             nn.ReLU(),
-            nn.Sequential(
-                *[ResBlock(channels) for _ in range(resblocks_per_reduction)]
-            ),
+            nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
         )
         self.reduction = 4
@@ -317,9 +301,7 @@ class UnifiedVoice(nn.Module):
         super().__init__()
         self.number_text_tokens = number_text_tokens
-        self.start_text_token = (
-            number_text_tokens * types if start_text_token is None else start_text_token
-        )
+        self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
         self.stop_text_token = 0
         self.number_mel_codes = number_mel_codes
         self.start_mel_token = start_mel_token
@@ -331,12 +313,8 @@ class UnifiedVoice(nn.Module):
         self.model_dim = model_dim
         self.max_conditioning_inputs = max_conditioning_inputs
         self.mel_length_compression = mel_length_compression
-        self.conditioning_encoder = ConditioningEncoder(
-            80, model_dim, num_attn_heads=heads
-        )
-        self.text_embedding = nn.Embedding(
-            self.number_text_tokens * types + 1, model_dim
-        )
+        self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
+        self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
         if use_mel_codes_as_input:
             self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
         else:
@@ -356,12 +334,8 @@ class UnifiedVoice(nn.Module):
             checkpointing,
         )
         if train_solo_embeddings:
-            self.mel_solo_embedding = nn.Parameter(
-                torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
-            )
-            self.text_solo_embedding = nn.Parameter(
-                torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
-            )
+            self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
+            self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
         else:
             self.mel_solo_embedding = 0
             self.text_solo_embedding = 0
@@ -414,9 +388,7 @@ class UnifiedVoice(nn.Module):
         preformatting to create a working TTS model.
         """
         # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
-        mel_lengths = torch.div(
-            wav_lengths, self.mel_length_compression, rounding_mode="trunc"
-        )
+        mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc")
         for b in range(len(mel_lengths)):
             actual_end = (
                 mel_lengths[b] + 1
@@ -436,31 +408,22 @@ class UnifiedVoice(nn.Module):
         return_latent=False,
     ):
         if second_inputs is not None:
-            emb = torch.cat(
-                [speech_conditioning_inputs, first_inputs, second_inputs], dim=1
-            )
+            emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
         else:
             emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
-        gpt_out = self.gpt(
-            inputs_embeds=emb, return_dict=True, output_attentions=get_attns
-        )
+        gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
         if get_attns:
             return gpt_out.attentions
-        enc = gpt_out.last_hidden_state[
-            :, 1:
-        ]  # The first logit is tied to the speech_conditioning_input
+        enc = gpt_out.last_hidden_state[:, 1:]  # The first logit is tied to the speech_conditioning_input
         enc = self.final_norm(enc)
         if return_latent:
             return (
                 enc[
                     :,
-                    speech_conditioning_inputs.shape[
-                        1
-                    ] : speech_conditioning_inputs.shape[1]
-                    + first_inputs.shape[1],
+                    speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1],
                 ],
                 enc[:, -second_inputs.shape[1] :],
             )
@@ -539,9 +502,7 @@ class UnifiedVoice(nn.Module):
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(
             text_inputs, self.start_text_token, self.stop_text_token
         )
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
-            text_inputs
-        )
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
             mel_codes, self.start_mel_token, self.stop_mel_token
         )
@@ -596,15 +557,13 @@ class UnifiedVoice(nn.Module):
         max_generate_length=None,
         typical_sampling=False,
         typical_mass=0.9,
-        **hf_generate_kwargs
+        **hf_generate_kwargs,
     ):
         text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(
             text_inputs, self.start_text_token, self.stop_text_token
         )
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
-            text_inputs
-        )
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
         conds = speech_conditioning_latent.unsqueeze(1)
         emb = torch.cat([conds, text_emb], dim=1)
@@ -628,20 +587,14 @@ class UnifiedVoice(nn.Module):
                 num_return_sequences % input_tokens.shape[0] == 0
             ), "The number of return sequences must be divisible by the number of input sequences"
             fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
-            input_tokens = input_tokens.repeat(
-                num_return_sequences // input_tokens.shape[0], 1
-            )
+            input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
         inputs = torch.cat([fake_inputs, input_tokens], dim=1)
         logits_processor = (
-            LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
-            if typical_sampling
-            else LogitsProcessorList()
+            LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
         )  # TODO disable this
         max_length = (
-            trunc_index + self.max_mel_tokens - 1
-            if max_generate_length is None
-            else trunc_index + max_generate_length
+            trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
        )
         gen = self.inference_model.generate(
             inputs,
@@ -651,7 +604,7 @@ class UnifiedVoice(nn.Module):
             max_length=max_length,
             logits_processor=logits_processor,
             num_return_sequences=num_return_sequences,
-            **hf_generate_kwargs
+            **hf_generate_kwargs,
         )
         return gen[:, trunc_index:]

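A recurring pattern in the autoregressive model above is expanding the cached conditioning embedding with repeat_interleave so its batch dimension matches the text embeddings before the two are concatenated (for example when generation expands the batch). A standalone sketch of just that step, with made-up tensor sizes:

import torch

cached_mel_emb = torch.randn(2, 100, 1024)  # (batch, mel_len, model_dim)
text_emb = torch.randn(6, 40, 1024)         # batch expanded to 6 by generation

# Same expansion as in GPT2InferenceModel.forward: repeat each conditioning row
# text_emb.shape[0] // cached_mel_emb.shape[0] times along dim 0, then concatenate.
mel_emb = cached_mel_emb.repeat_interleave(text_emb.shape[0] // cached_mel_emb.shape[0], 0)
emb = torch.cat([mel_emb, text_emb], dim=1)
print(emb.shape)  # torch.Size([6, 140, 1024])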
View File

@@ -1,13 +1,7 @@
 import torch
 import torch.nn as nn
-from TTS.tts.layers.tortoise.arch_utils import (
-    AttentionBlock,
-    Downsample,
-    Upsample,
-    normalization,
-    zero_module,
-)
+from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module
 class ResBlock(nn.Module):
@@ -54,19 +48,13 @@ class ResBlock(nn.Module):
             normalization(self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                nn.Conv1d(
-                    self.out_channels, self.out_channels, kernel_size, padding=padding
-                )
-            ),
+            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
         )
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = nn.Conv1d(
-                dims, channels, self.out_channels, kernel_size, padding=padding
-            )
+            self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding)
         else:
             self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
@@ -104,24 +92,14 @@ class AudioMiniEncoder(nn.Module):
         self.layers = depth
         for l in range(depth):
             for r in range(resnet_blocks):
-                res.append(
-                    ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
-                )
-            res.append(
-                Downsample(
-                    ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
-                )
-            )
+                res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
+            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
-        self.final = nn.Sequential(
-            normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
-        )
+        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
         attn = []
         for a in range(attn_blocks):
-            attn.append(
-                AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
-            )
+            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
         self.attn = nn.Sequential(*attn)
         self.dim = embedding_dim

View File

@@ -12,9 +12,10 @@ def exists(val):
     return val is not None
-def masked_mean(t, mask, dim = 1):
-    t = t.masked_fill(~mask[:, :, None], 0.)
-    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
+def masked_mean(t, mask, dim=1):
+    t = t.masked_fill(~mask[:, :, None], 0.0)
+    return t.sum(dim=1) / mask.sum(dim=1)[..., None]
+
 class CLVP(nn.Module):
     """
@@ -25,23 +26,23 @@ class CLVP(nn.Module):
     """
     def __init__(
         self,
         *,
         dim_text=512,
         dim_speech=512,
         dim_latent=512,
         num_text_tokens=256,
         text_enc_depth=6,
         text_seq_len=120,
         text_heads=8,
         num_speech_tokens=8192,
         speech_enc_depth=6,
         speech_heads=8,
         speech_seq_len=250,
         text_mask_percentage=0,
         voice_mask_percentage=0,
         wav_token_compression=1024,
         use_xformers=False,
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim_text)
@@ -59,13 +60,14 @@ class CLVP(nn.Module):
                 dim=dim_text,
                 depth=text_enc_depth,
                 heads=text_heads,
-                ff_dropout=.1,
+                ff_dropout=0.1,
                 ff_mult=2,
-                attn_dropout=.1,
+                attn_dropout=0.1,
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
-            ))
+            ),
+        )
         self.speech_transformer = CheckpointedXTransformerEncoder(
             needs_permute=False,
             exit_permute=False,
@@ -74,20 +76,23 @@ class CLVP(nn.Module):
                 dim=dim_speech,
                 depth=speech_enc_depth,
                 heads=speech_heads,
-                ff_dropout=.1,
+                ff_dropout=0.1,
                 ff_mult=2,
-                attn_dropout=.1,
+                attn_dropout=0.1,
                 use_rmsnorm=True,
                 ff_glu=True,
                 rotary_pos_emb=True,
-            ))
+            ),
+        )
         else:
-            self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
-                                                heads=text_heads)
-            self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
-                                                  depth=speech_enc_depth, heads=speech_heads)
+            self.text_transformer = Transformer(
+                causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads
+            )
+            self.speech_transformer = Transformer(
+                causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads
+            )
-        self.temperature = nn.Parameter(torch.tensor(1.))
+        self.temperature = nn.Parameter(torch.tensor(1.0))
         self.text_mask_percentage = text_mask_percentage
         self.voice_mask_percentage = voice_mask_percentage
         self.wav_token_compression = wav_token_compression
@@ -96,12 +101,7 @@ class CLVP(nn.Module):
             self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
             self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
-    def forward(
-        self,
-        text,
-        speech_tokens,
-        return_loss=False
-    ):
+    def forward(self, text, speech_tokens, return_loss=False):
         b, device = text.shape[0], text.device
         if self.training:
             text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
@@ -131,25 +131,29 @@ class CLVP(nn.Module):
         temp = self.temperature.exp()
         if not return_loss:
-            sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
+            sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
             return sim
-        sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
+        sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
         labels = torch.arange(b, device=device)
         loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
         return loss
-if __name__ == '__main__':
-    clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
-    clip(torch.randint(0,256,(2,120)),
-         torch.tensor([50,100]),
-         torch.randint(0,8192,(2,250)),
-         torch.tensor([101,102]),
-         return_loss=True)
-    nonloss = clip(torch.randint(0,256,(2,120)),
-                   torch.tensor([50,100]),
-                   torch.randint(0,8192,(2,250)),
-                   torch.tensor([101,102]),
-                   return_loss=False)
-    print(nonloss.shape)
+if __name__ == "__main__":
+    clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
+    clip(
+        torch.randint(0, 256, (2, 120)),
+        torch.tensor([50, 100]),
+        torch.randint(0, 8192, (2, 250)),
+        torch.tensor([101, 102]),
+        return_loss=True,
+    )
+    nonloss = clip(
+        torch.randint(0, 256, (2, 120)),
+        torch.tensor([50, 100]),
+        torch.randint(0, 8192, (2, 250)),
+        torch.tensor([101, 102]),
+        return_loss=False,
+    )
+    print(nonloss.shape)

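The CLVP forward hunk above computes a symmetric contrastive loss: a temperature-scaled similarity matrix between text and speech latents, followed by cross-entropy in both directions. A minimal sketch of that loss math with random stand-in latents (the L2 normalization and the dimensions here are illustrative assumptions, not taken from the diff):

import torch
import torch.nn.functional as F
from torch import einsum

b, d = 4, 512
text_latents = F.normalize(torch.randn(b, d), dim=-1)
speech_latents = F.normalize(torch.randn(b, d), dim=-1)
temp = torch.tensor(1.0).exp()

# Pairwise similarities; matching pairs sit on the diagonal.
sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
labels = torch.arange(b)
# Symmetric cross-entropy over text->speech and speech->text directions.
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
print(loss.item())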
View File

@@ -17,16 +17,7 @@ def masked_mean(t, mask):
 class CollapsingTransformer(nn.Module):
-    def __init__(
-        self,
-        model_dim,
-        output_dims,
-        heads,
-        dropout,
-        depth,
-        mask_percentage=0,
-        **encoder_kwargs
-    ):
+    def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
         super().__init__()
         self.transformer = ContinuousTransformerWrapper(
             max_seq_len=-1,
@@ -105,9 +96,7 @@ class CVVP(nn.Module):
         self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
         if mel_codes is None:
-            self.speech_emb = nn.Conv1d(
-                mel_channels, model_dim, kernel_size=5, padding=2
-            )
+            self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
         else:
             self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
         self.speech_transformer = CollapsingTransformer(
@@ -135,9 +124,7 @@ class CVVP(nn.Module):
         enc_speech = self.speech_transformer(speech_emb)
         speech_latents = self.to_speech_latent(enc_speech)
-        cond_latents, speech_latents = map(
-            lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
-        )
+        cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
         temp = self.temperature.exp()
         if not return_loss:

View File

@ -13,8 +13,8 @@ import math
import numpy as np import numpy as np
import torch import torch
import torch as th import torch as th
from tqdm import tqdm
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
from tqdm import tqdm
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
@ -38,18 +38,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
# Force variances to be Tensors. Broadcasting helps convert scalars to # Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp(). # Tensors, but it does not work for th.exp().
logvar1, logvar2 = [ logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * ( return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
-1.0
+ logvar2
- logvar1
+ th.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
)
def approx_standard_normal_cdf(x): def approx_standard_normal_cdf(x):
@ -112,9 +103,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
scale = 1000 / num_diffusion_timesteps scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001 beta_start = scale * 0.0001
beta_end = scale * 0.02 beta_end = scale * 0.02
return np.linspace( return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif schedule_name == "cosine": elif schedule_name == "cosine":
return betas_for_alpha_bar( return betas_for_alpha_bar(
num_diffusion_timesteps, num_diffusion_timesteps,
@ -149,9 +138,9 @@ class ModelMeanType(enum.Enum):
Which type of output the model predicts. Which type of output the model predicts.
""" """
PREVIOUS_X = 'previous_x' # the model predicts x_{t-1} PREVIOUS_X = "previous_x" # the model predicts x_{t-1}
START_X = 'start_x' # the model predicts x_0 START_X = "start_x" # the model predicts x_0
EPSILON = 'epsilon' # the model predicts epsilon EPSILON = "epsilon" # the model predicts epsilon
class ModelVarType(enum.Enum): class ModelVarType(enum.Enum):
@ -162,17 +151,17 @@ class ModelVarType(enum.Enum):
values between FIXED_SMALL and FIXED_LARGE, making its job easier. values between FIXED_SMALL and FIXED_LARGE, making its job easier.
""" """
LEARNED = 'learned' LEARNED = "learned"
FIXED_SMALL = 'fixed_small' FIXED_SMALL = "fixed_small"
FIXED_LARGE = 'fixed_large' FIXED_LARGE = "fixed_large"
LEARNED_RANGE = 'learned_range' LEARNED_RANGE = "learned_range"
class LossType(enum.Enum): class LossType(enum.Enum):
MSE = 'mse' # use raw MSE loss (and KL when learning variances) MSE = "mse" # use raw MSE loss (and KL when learning variances)
RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances) RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances)
KL = 'kl' # use the variational lower-bound KL = "kl" # use the variational lower-bound
RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB
def is_vb(self): def is_vb(self):
return self == LossType.KL or self == LossType.RESCALED_KL return self == LossType.KL or self == LossType.RESCALED_KL
@ -239,22 +228,12 @@ class GaussianDiffusion:
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0) # calculations for posterior q(x_{t-1} | x_t, x_0)
self.posterior_variance = ( self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
# log calculation clipped because the posterior variance is 0 at the # log calculation clipped because the posterior variance is 0 at the
# beginning of the diffusion chain. # beginning of the diffusion chain.
self.posterior_log_variance_clipped = np.log( self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
np.append(self.posterior_variance[1], self.posterior_variance[1:]) self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
) self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
self.posterior_mean_coef1 = (
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
self.posterior_mean_coef2 = (
(1.0 - self.alphas_cumprod_prev)
* np.sqrt(alphas)
/ (1.0 - self.alphas_cumprod)
)
def q_mean_variance(self, x_start, t): def q_mean_variance(self, x_start, t):
""" """
@ -264,13 +243,9 @@ class GaussianDiffusion:
:param t: the number of diffusion steps (minus 1). Here, 0 means one step. :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape. :return: A tuple (mean, variance, log_variance), all of x_start's shape.
""" """
mean = ( mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
)
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = _extract_into_tensor( log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance return mean, variance, log_variance
def q_sample(self, x_start, t, noise=None): def q_sample(self, x_start, t, noise=None):
@ -289,8 +264,7 @@ class GaussianDiffusion:
assert noise.shape == x_start.shape assert noise.shape == x_start.shape
return ( return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
* noise
) )
def q_posterior_mean_variance(self, x_start, x_t, t): def q_posterior_mean_variance(self, x_start, x_t, t):
@ -306,9 +280,7 @@ class GaussianDiffusion:
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
) )
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor( posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
self.posterior_log_variance_clipped, t, x_t.shape
)
assert ( assert (
posterior_mean.shape[0] posterior_mean.shape[0]
== posterior_variance.shape[0] == posterior_variance.shape[0]
@ -317,9 +289,7 @@ class GaussianDiffusion:
) )
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance( def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
):
""" """
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
the initial x, x_0. the initial x, x_0.
@ -358,9 +328,7 @@ class GaussianDiffusion:
model_log_variance = model_var_values model_log_variance = model_var_values
model_variance = th.exp(model_log_variance) model_variance = th.exp(model_log_variance)
else: else:
min_log = _extract_into_tensor( min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
self.posterior_log_variance_clipped, t, x.shape
)
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
# The model_var_values is [-1, 1] for [min_var, max_var]. # The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2 frac = (model_var_values + 1) / 2
@ -398,26 +366,18 @@ class GaussianDiffusion:
return x return x
if self.model_mean_type == ModelMeanType.PREVIOUS_X: if self.model_mean_type == ModelMeanType.PREVIOUS_X:
pred_xstart = process_xstart( pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
)
model_mean = model_output model_mean = model_output
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
if self.model_mean_type == ModelMeanType.START_X: if self.model_mean_type == ModelMeanType.START_X:
pred_xstart = process_xstart(model_output) pred_xstart = process_xstart(model_output)
else: else:
pred_xstart = process_xstart( pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
)
model_mean, _, _ = self.q_posterior_mean_variance(
x_start=pred_xstart, x_t=x, t=t
)
else: else:
raise NotImplementedError(self.model_mean_type) raise NotImplementedError(self.model_mean_type)
assert ( assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
)
return { return {
"mean": model_mean, "mean": model_mean,
"variance": model_variance, "variance": model_variance,
@ -436,16 +396,12 @@ class GaussianDiffusion:
assert x_t.shape == xprev.shape assert x_t.shape == xprev.shape
return ( # (xprev - coef2*x_t) / coef1 return ( # (xprev - coef2*x_t) / coef1
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
- _extract_into_tensor( - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
)
* x_t
) )
def _predict_eps_from_xstart(self, x_t, t, pred_xstart): def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return ( return (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
- pred_xstart
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _scale_timesteps(self, t): def _scale_timesteps(self, t):
@ -463,9 +419,7 @@ class GaussianDiffusion:
This uses the conditioning strategy from Sohl-Dickstein et al. (2015). This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
""" """
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
new_mean = ( new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
)
return new_mean return new_mean
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
@ -481,16 +435,13 @@ class GaussianDiffusion:
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
eps = eps - (1 - alpha_bar).sqrt() * cond_fn( eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs)
x, self._scale_timesteps(t), **model_kwargs
)
out = p_mean_var.copy() out = p_mean_var.copy()
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
out["mean"], _, _ = self.q_posterior_mean_variance( out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
x_start=out["pred_xstart"], x_t=x, t=t
)
return out return out
def k_diffusion_sample_loop( def k_diffusion_sample_loop(
self, self,
k_sampler, k_sampler,
@ -512,9 +463,7 @@ class GaussianDiffusion:
def model_split(*args, **kwargs): def model_split(*args, **kwargs):
model_output = model(*args, **kwargs) model_output = model(*args, **kwargs)
model_epsilon, model_var = th.split( model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1)
model_output, model_output.shape[1] // 2, dim=1
)
return model_epsilon, model_var return model_epsilon, model_var
# #
@ -523,9 +472,7 @@ class GaussianDiffusion:
print(th.tensor(self.betas)) print(th.tensor(self.betas))
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas))
""" """
noise_schedule = NoiseScheduleVP( noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4)
schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4
)
def model_fn_prewrap(x, t, *args, **kwargs): def model_fn_prewrap(x, t, *args, **kwargs):
""" """
@ -584,11 +531,10 @@ class GaussianDiffusion:
if self.conditioning_free is not True: if self.conditioning_free is not True:
raise RuntimeError("cond_free must be true") raise RuntimeError("cond_free must be true")
with tqdm(total=self.num_timesteps) as pbar: with tqdm(total=self.num_timesteps) as pbar:
return self.k_diffusion_sample_loop( return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs
)
else: else:
raise RuntimeError("sampler not impl") raise RuntimeError("sampler not impl")
def p_sample( def p_sample(
self, self,
model, model,
@ -625,13 +571,9 @@ class GaussianDiffusion:
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
) )
noise = th.randn_like(x) noise = th.randn_like(x)
nonzero_mask = ( nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
if cond_fn is not None: if cond_fn is not None:
out["mean"] = self.condition_mean( out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
cond_fn, out, x, t, model_kwargs=model_kwargs
)
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]} return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -758,20 +700,11 @@ class GaussianDiffusion:
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
sigma = ( sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
eta
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
)
# Equation 12. # Equation 12.
noise = th.randn_like(x) noise = th.randn_like(x)
mean_pred = ( mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
out["pred_xstart"] * th.sqrt(alpha_bar_prev) nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
) # no noise when t == 0
sample = mean_pred + nonzero_mask * sigma * noise sample = mean_pred + nonzero_mask * sigma * noise
return {"sample": sample, "pred_xstart": out["pred_xstart"]} return {"sample": sample, "pred_xstart": out["pred_xstart"]}
@ -800,16 +733,12 @@ class GaussianDiffusion:
# Usually our model outputs epsilon, but we re-derive it # Usually our model outputs epsilon, but we re-derive it
# in case we used x_start or x_prev prediction. # in case we used x_start or x_prev prediction.
eps = ( eps = (
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
- out["pred_xstart"]
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
# Equation 12. reversed # Equation 12. reversed
mean_pred = ( mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ th.sqrt(1 - alpha_bar_next) * eps
)
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
@ -897,9 +826,7 @@ class GaussianDiffusion:
yield out yield out
img = out["sample"] img = out["sample"]
def _vb_terms_bpd( def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
):
""" """
Get a term for the variational lower-bound. Get a term for the variational lower-bound.
@ -910,15 +837,9 @@ class GaussianDiffusion:
- 'output': a shape [N] tensor of NLLs or KLs. - 'output': a shape [N] tensor of NLLs or KLs.
- 'pred_xstart': the x_0 predictions. - 'pred_xstart': the x_0 predictions.
""" """
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
x_start=x_start, x_t=x_t, t=t out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
) kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
out = self.p_mean_variance(
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
)
kl = normal_kl(
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
)
kl = mean_flat(kl) / np.log(2.0) kl = mean_flat(kl) / np.log(2.0)
decoder_nll = -discretized_gaussian_log_likelihood( decoder_nll = -discretized_gaussian_log_likelihood(
@ -969,7 +890,7 @@ class GaussianDiffusion:
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
if isinstance(model_outputs, tuple):
model_output = model_outputs[0]
terms["extra_outputs"] = model_outputs[1:]
else:
model_output = model_outputs
@ -996,9 +917,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0 terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X: if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance( target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported. x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X: elif self.model_mean_type == ModelMeanType.START_X:
target = x_start target = x_start
@ -1020,7 +939,9 @@ class GaussianDiffusion:
return terms

def autoregressive_training_losses(
    self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None
):
"""
Compute training losses for a single timestep.
@ -1068,9 +989,7 @@ class GaussianDiffusion:
terms["vb"] *= self.num_timesteps / 1000.0 terms["vb"] *= self.num_timesteps / 1000.0
if self.model_mean_type == ModelMeanType.PREVIOUS_X: if self.model_mean_type == ModelMeanType.PREVIOUS_X:
target = self.q_posterior_mean_variance( target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
x_start=x_start, x_t=x_t, t=t
)[0]
x_start_pred = torch.zeros(x_start) # Not supported. x_start_pred = torch.zeros(x_start) # Not supported.
elif self.model_mean_type == ModelMeanType.START_X: elif self.model_mean_type == ModelMeanType.START_X:
target = x_start target = x_start
@ -1105,9 +1024,7 @@ class GaussianDiffusion:
batch_size = x_start.shape[0]
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
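As a reading aid for _prior_bpd above: it evaluates the prior term of the variational bound in bits per dimension,

    L_prior = KL( q(x_T | x_0) || N(0, I) ) / ln 2,

where normal_kl is the closed-form Gaussian KL,

    KL( N(mu1, s1^2) || N(mu2, s2^2) ) = ln(s2 / s1) + (s1^2 + (mu1 - mu2)^2) / (2 * s2^2) - 1/2,

and the division by np.log(2.0) converts nats to bits.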
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
@ -1183,9 +1100,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
scale = 1000 / num_diffusion_timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
elif schedule_name == "cosine":
return betas_for_alpha_bar(
    num_diffusion_timesteps,
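The cosine branch is cut off by this hunk. As a hedged sketch of what a betas_for_alpha_bar-style helper typically computes (following the guided-diffusion convention this module derives from; the exact lambda used here is not shown, and the name betas_for_alpha_bar_sketch is illustrative only):

    import math
    import numpy as np

    def betas_for_alpha_bar_sketch(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
        # beta_t = 1 - alpha_bar((t + 1) / T) / alpha_bar(t / T), clipped to max_beta.
        betas = []
        for i in range(num_diffusion_timesteps):
            t1 = i / num_diffusion_timesteps
            t2 = (i + 1) / num_diffusion_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return np.array(betas)

    # Cosine schedule in the style of Nichol & Dhariwal (2021), with their suggested 0.008 offset.
    cosine_betas = betas_for_alpha_bar_sketch(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)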
@ -1219,19 +1134,13 @@ class SpacedDiffusion(GaussianDiffusion):
kwargs["betas"] = np.array(new_betas) kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs) super().__init__(**kwargs)
def p_mean_variance( def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses( def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs) return super().training_losses(self._wrap_model(model), *args, **kwargs)
def autoregressive_training_losses( def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
self, model, *args, **kwargs
): # pylint: disable=signature-differs
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs): def condition_mean(self, cond_fn, *args, **kwargs):
@ -1244,9 +1153,7 @@ class SpacedDiffusion(GaussianDiffusion):
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
return model
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)

def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
@ -1281,9 +1188,7 @@ def space_timesteps(num_timesteps, section_counts):
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
@ -1292,9 +1197,7 @@ def space_timesteps(num_timesteps, section_counts):
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
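Usage sketch for space_timesteps (assuming the upstream guided-diffusion semantics this helper mirrors; values are illustrative):

    # Keep 50 evenly spaced steps out of 1000 via the "ddimN" integer-stride branch shown above.
    kept = space_timesteps(1000, "ddim50")     # -> {0, 20, 40, ..., 980}, 50 retained timesteps
    # Or split 300 steps into three equal sections retaining 10, 15 and 20 steps respectively.
    kept = space_timesteps(300, "10,15,20")    # -> a set of 45 original timestep indices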
@ -1315,6 +1218,7 @@ class _WrappedModel:
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps

def __call__(self, x, ts, **kwargs):
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]
@ -1323,6 +1227,7 @@ class _WrappedModel:
model_output = self.model(x, new_ts, **kwargs)
return model_output
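As a reading aid, the lookup above maps the sampler's spaced indices back onto the timesteps the model was trained with; with hypothetical values:

    import torch as th

    timestep_map = [0, 250, 500, 750, 999]    # a 5-step respacing of a 1000-step schedule (illustrative)
    ts = th.tensor([0, 3, 4])                 # indices used by the spaced sampler
    new_ts = th.tensor(timestep_map)[ts]      # tensor([  0, 750, 999]): original training timesteps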
class _WrappedAutoregressiveModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
@ -1337,6 +1242,7 @@ class _WrappedAutoregressiveModel:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, x0, new_ts, **kwargs)

def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
@ -1350,4 +1256,4 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)
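Usage sketch for _extract_into_tensor (shapes are illustrative):

    import numpy as np
    import torch as th

    alphas_cumprod = np.linspace(1.0, 0.01, 1000)             # hypothetical 1-D schedule
    t = th.tensor([0, 499, 999])                              # one timestep per batch element
    coef = _extract_into_tensor(alphas_cumprod, t, (3, 100, 80))
    # coef has shape (3, 100, 80); each batch element is filled with its own scalar coefficient.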

View File

@ -29,11 +29,9 @@ def timestep_embedding(timesteps, dim, max_period=10000):
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
    device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
@ -98,17 +96,13 @@ class ResBlock(TimestepBlock):
normalization(self.out_channels), normalization(self.out_channels),
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
nn.Conv1d( nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
self.out_channels, self.out_channels, kernel_size, padding=padding
),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
else: else:
self.skip_connection = nn.Conv1d( self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
channels, self.out_channels, eff_kernel, padding=eff_padding
)
def forward(self, x, emb): def forward(self, x, emb):
h = self.in_layers(x) h = self.in_layers(x)
@ -137,9 +131,7 @@ class DiffusionLayer(TimestepBlock):
dims=1, dims=1,
use_scale_shift_norm=True, use_scale_shift_norm=True,
) )
self.attn = AttentionBlock( self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
model_channels, num_heads, relative_pos_embeddings=True
)
def forward(self, x, time_emb): def forward(self, x, time_emb):
y = self.resblk(x, time_emb) y = self.resblk(x, time_emb)
@ -239,16 +231,11 @@ class DiffusionTts(nn.Module):
DiffusionLayer(model_channels, dropout, num_heads), DiffusionLayer(model_channels, dropout, num_heads),
) )
self.integrating_conv = nn.Conv1d( self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
model_channels * 2, model_channels, kernel_size=1
)
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
DiffusionLayer(model_channels, dropout, num_heads)
for _ in range(num_layers)
]
+ [ + [
ResBlock( ResBlock(
model_channels, model_channels,
@ -275,9 +262,7 @@ class DiffusionTts(nn.Module):
+ list(self.code_converter.parameters()) + list(self.code_converter.parameters())
+ list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters())
+ list(self.latent_conditioner.parameters()), + list(self.latent_conditioner.parameters()),
"timestep_integrator": list( "timestep_integrator": list(self.conditioning_timestep_integrator.parameters())
self.conditioning_timestep_integrator.parameters()
)
+ list(self.integrating_conv.parameters()), + list(self.integrating_conv.parameters()),
"time_embed": list(self.time_embed.parameters()), "time_embed": list(self.time_embed.parameters()),
} }
@ -285,9 +270,7 @@ class DiffusionTts(nn.Module):
def get_conditioning(self, conditioning_input): def get_conditioning(self, conditioning_input):
speech_conditioning_input = ( speech_conditioning_input = (
conditioning_input.unsqueeze(1) conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
if len(conditioning_input.shape) == 3
else conditioning_input
) )
conds = [] conds = []
for j in range(speech_conditioning_input.shape[1]): for j in range(speech_conditioning_input.shape[1]):
@ -313,29 +296,20 @@ class DiffusionTts(nn.Module):
else:
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
code_emb = self.code_converter(code_emb)
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
if self.training and self.unconditioned_percentage > 0:
unconditioned_batches = (
    torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
)
code_emb = torch.where(
    unconditioned_batches,
    self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
    code_emb,
)
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest")
if not return_code_pred:
return expanded_code_emb
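The masking above is the usual conditioning-dropout trick behind classifier-free guidance. A minimal self-contained sketch of the same idea (hypothetical shapes and names, not this model's API):

    import torch

    def drop_conditioning(code_emb, uncond_embedding, p_uncond=0.1):
        # Replace the conditioning of a random subset of batch elements with a learned
        # "unconditioned" embedding so the model can also be run without conditioning.
        mask = torch.rand(code_emb.shape[0], 1, 1, device=code_emb.device) < p_uncond
        return torch.where(mask, uncond_embedding.expand_as(code_emb), code_emb)

    emb = torch.randn(4, 512, 100)     # (batch, channels, frames), illustrative
    uncond = torch.randn(1, 512, 1)    # learned embedding, broadcast over batch and time
    emb = drop_conditioning(emb, uncond, 0.1)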
@ -376,10 +350,7 @@ class DiffusionTts(nn.Module):
unused_params = [] unused_params = []
if conditioning_free: if conditioning_free:
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
unused_params.extend( unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
list(self.code_converter.parameters())
+ list(self.code_embedding.parameters())
)
unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.extend(list(self.latent_conditioner.parameters()))
else: else:
if precomputed_aligned_embeddings is not None: if precomputed_aligned_embeddings is not None:
@ -390,8 +361,7 @@ class DiffusionTts(nn.Module):
) )
if is_latent(aligned_conditioning): if is_latent(aligned_conditioning):
unused_params.extend( unused_params.extend(
list(self.code_converter.parameters()) list(self.code_converter.parameters()) + list(self.code_embedding.parameters())
+ list(self.code_embedding.parameters())
) )
else: else:
unused_params.extend(list(self.latent_conditioner.parameters())) unused_params.extend(list(self.latent_conditioner.parameters()))

View File

@ -1,6 +1,8 @@
import math import math
import torch import torch
class NoiseScheduleVP: class NoiseScheduleVP:
def __init__( def __init__(
self, self,
@ -107,11 +109,7 @@ class NoiseScheduleVP:
log_alphas = 0.5 * torch.log(alphas_cumprod) log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas) self.total_N = len(log_alphas)
self.T = 1.0 self.T = 1.0
self.t_array = ( self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
torch.linspace(0.0, 1.0, self.total_N + 1)[1:]
.reshape((1, -1))
.to(dtype=dtype)
)
self.log_alpha_array = log_alphas.reshape( self.log_alpha_array = log_alphas.reshape(
( (
1, 1,
@ -131,9 +129,7 @@ class NoiseScheduleVP:
/ math.pi / math.pi
- self.cosine_s - self.cosine_s
) )
self.cosine_log_alpha_0 = math.log( self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
)
self.schedule = schedule self.schedule = schedule
if schedule == "cosine": if schedule == "cosine":
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
@ -157,11 +153,7 @@ class NoiseScheduleVP:
elif self.schedule == "cosine": elif self.schedule == "cosine":
def log_alpha_fn(s): def log_alpha_fn(s):
return torch.log( return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
torch.cos(
(s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
)
)
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t return log_alpha_t
@ -191,17 +183,11 @@ class NoiseScheduleVP:
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
""" """
if self.schedule == "linear": if self.schedule == "linear":
tmp = ( tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
2.0
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
Delta = self.beta_0**2 + tmp Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete": elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp( log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
)
t = interpolate_fn( t = interpolate_fn(
log_alpha.reshape((-1, 1)), log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
@ -345,14 +331,10 @@ def model_wrapper(
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return (x - alpha_t * output) / sigma_t
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return alpha_t * output + sigma_t * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
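As a reading aid, the branches above convert the wrapped model's output to a noise prediction eps_hat, where alpha_t and sigma_t are the marginal coefficients of x_t = alpha_t * x_0 + sigma_t * eps:

    x_start:  eps_hat = (x - alpha_t * x0_hat) / sigma_t
    v:        eps_hat = alpha_t * v_hat + sigma_t * x
    score:    eps_hat = -sigma_t * score_hat   (the score branch is truncated in this hunk; the identity is stated here as an assumption)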
@ -482,9 +464,7 @@ class DPM_Solver:
p = self.dynamic_thresholding_ratio p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims( s = expand_dims(
torch.maximum( torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)),
s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
),
dims, dims,
) )
x0 = torch.clamp(x0, -s, s) / s x0 = torch.clamp(x0, -s, s) / s
@ -501,9 +481,7 @@ class DPM_Solver:
Return the data prediction model (with corrector). Return the data prediction model (with corrector).
""" """
noise = self.noise_prediction_fn(x, t) noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha( alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
t
), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None: if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t) x0 = self.correcting_x0_fn(x0, t)
@ -536,30 +514,20 @@ class DPM_Solver:
if skip_type == "logSNR": if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace( logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps) return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform": elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device) return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic": elif skip_type == "time_quadratic":
t_order = 2 t_order = 2
t = ( t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
.pow(t_order)
.to(device)
)
return t return t
else: else:
raise ValueError( raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format( "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
skip_type
)
) )
def get_orders_and_timesteps_for_singlestep_solver( def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
self, steps, order, skip_type, t_T, t_0, device
):
""" """
Get the order of each step for sampling by the singlestep DPM-Solver. Get the order of each step for sampling by the singlestep DPM-Solver.
@ -664,9 +632,7 @@ class DPM_Solver:
dims = x.dim() dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
s
), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -716,11 +682,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None: if r1 is None:
r1 = 0.5 r1 = 0.5
ns = self.noise_schedule ns = self.noise_schedule
@ -766,10 +728,7 @@ class DPM_Solver:
if model_s is None: if model_s is None:
model_s = self.model_fn(x, s) model_s = self.model_fn(x, s)
x_s1 = ( x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
torch.exp(log_alpha_s1 - log_alpha_s) * x
- (sigma_s1 * phi_11) * model_s
)
model_s1 = self.model_fn(x_s1, s1) model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver": if solver_type == "dpmsolver":
x_t = ( x_t = (
@ -820,11 +779,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None: if r1 is None:
r1 = 1.0 / 3.0 r1 = 1.0 / 3.0
if r2 is None: if r2 is None:
@ -901,9 +856,7 @@ class DPM_Solver:
if model_s is None: if model_s is None:
model_s = self.model_fn(x, s) model_s = self.model_fn(x, s)
if model_s1 is None: if model_s1 is None:
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - ( x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
sigma_s1 * phi_11
) * model_s
model_s1 = self.model_fn(x_s1, s1) model_s1 = self.model_fn(x_s1, s1)
x_s2 = ( x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x (torch.exp(log_alpha_s2 - log_alpha_s)) * x
@ -934,9 +887,7 @@ class DPM_Solver:
else: else:
return x_t return x_t
def multistep_dpm_solver_second_update( def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
""" """
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
@ -951,11 +902,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if solver_type not in ["dpmsolver", "taylor"]: if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError( raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
solver_type
)
)
ns = self.noise_schedule ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
@ -964,9 +911,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t), ns.marginal_lambda(t),
) )
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -977,11 +922,7 @@ class DPM_Solver:
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if solver_type == "dpmsolver":
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
elif solver_type == "taylor":
x_t = (
    (sigma_t / sigma_prev_0) * x
@ -1004,9 +945,7 @@ class DPM_Solver:
)
return x_t
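In the dpmsolver++ / "dpmsolver" branch above, with h = lambda_t - lambda_{s0}, phi_1 = expm1(-h) and D1_0 a scaled difference of the two previous data predictions (defined earlier in the method), the step can be read as

    x_t = (sigma_t / sigma_{s0}) * x - alpha_t * phi_1 * model_prev_0 - 0.5 * alpha_t * phi_1 * D1_0,

i.e. a second-order multistep (DPM-Solver++(2M)-style) update.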
def multistep_dpm_solver_third_update( def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
):
""" """
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
@ -1029,9 +968,7 @@ class DPM_Solver:
ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t), ns.marginal_lambda(t),
) )
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t) alpha_t = torch.exp(log_alpha_t)
@ -1093,9 +1030,7 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if order == 1: if order == 1:
return self.dpm_solver_first_update( return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
x, s, t, return_intermediate=return_intermediate
)
elif order == 2: elif order == 2:
return self.singlestep_dpm_solver_second_update( return self.singlestep_dpm_solver_second_update(
x, x,
@ -1118,9 +1053,7 @@ class DPM_Solver:
else: else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update( def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"
):
""" """
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
@ -1136,17 +1069,11 @@ class DPM_Solver:
x_t: A pytorch tensor. The approximated solution at time `t`. x_t: A pytorch tensor. The approximated solution at time `t`.
""" """
if order == 1: if order == 1:
return self.dpm_solver_first_update( return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
)
elif order == 2: elif order == 2:
return self.multistep_dpm_solver_second_update( return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
elif order == 3: elif order == 3:
return self.multistep_dpm_solver_third_update( return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
else: else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
@ -1198,9 +1125,7 @@ class DPM_Solver:
return self.dpm_solver_first_update(x, s, t, return_intermediate=True) return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
def higher_update(x, s, t, **kwargs): def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_second_update( return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
elif order == 3: elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0 r1, r2 = 1.0 / 3.0, 2.0 / 3.0
@ -1211,16 +1136,10 @@ class DPM_Solver:
) )
def higher_update(x, s, t, **kwargs): def higher_update(x, s, t, **kwargs):
return self.singlestep_dpm_solver_third_update( return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
else: else:
raise ValueError( raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
"For adaptive step size solver, order must be 2 or 3, got {}".format(
order
)
)
while torch.abs((s - t_0)).mean() > t_err: while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h) t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t) x_lower, lower_noise_kwargs = lower_update(x, s, t)
@ -1231,9 +1150,7 @@ class DPM_Solver:
) )
def norm_fn(v): def norm_fn(v):
return torch.sqrt( return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
)
E = norm_fn((x_higher - x_lower) / delta).max() E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0): if torch.all(E <= 1.0):
@ -1259,9 +1176,7 @@ class DPM_Solver:
Returns: Returns:
xt with shape `(t_size, batch_size, *shape)`. xt with shape `(t_size, batch_size, *shape)`.
""" """
alpha_t, sigma_t = self.noise_schedule.marginal_alpha( alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
t
), self.noise_schedule.marginal_std(t)
if noise is None: if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device) noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape)) x = x.reshape((-1, *x.shape))
@ -1468,9 +1383,7 @@ class DPM_Solver:
) )
elif method == "multistep": elif method == "multistep":
assert steps >= order assert steps >= order
timesteps = self.get_time_steps( timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
)
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# Init the initial values. # Init the initial values.
step = 0 step = 0
@ -1527,10 +1440,7 @@ class DPM_Solver:
model_prev_list[-1] = self.model_fn(x, t) model_prev_list[-1] = self.model_fn(x, t)
elif method in ["singlestep", "singlestep_fixed"]: elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep": if method == "singlestep":
( (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps, steps=steps,
order=order, order=order,
skip_type=skip_type, skip_type=skip_type,
@ -1543,9 +1453,7 @@ class DPM_Solver:
orders = [ orders = [
order, order,
] * K ] * K
timesteps_outer = self.get_time_steps( timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
)
for step, order in enumerate(orders): for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1] s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps( timesteps_inner = self.get_time_steps(
@ -1559,9 +1467,7 @@ class DPM_Solver:
h = lambda_inner[-1] - lambda_inner[0] h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update( x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
x, s, t, order, solver_type=solver_type, r1=r1, r2=r2
)
if self.correcting_xt_fn is not None: if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step) x = self.correcting_xt_fn(x, t, step)
if return_intermediate: if return_intermediate:
@ -1613,9 +1519,7 @@ def interpolate_fn(x, xp, yp):
cand_start_idx, cand_start_idx,
), ),
) )
end_idx = torch.where( end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where( start_idx2 = torch.where(
@ -1628,12 +1532,8 @@ def interpolate_fn(x, xp, yp):
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
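Usage sketch for interpolate_fn, assuming the [N, C] query / [C, K] keypoint shape convention of the reference DPM-Solver implementation (values are illustrative):

    import torch

    xp = torch.tensor([[0.0, 1.0, 2.0]])    # [C=1, K=3] keypoint x-coordinates
    yp = torch.tensor([[0.0, 10.0, 40.0]])  # [C=1, K=3] keypoint y-values
    x = torch.tensor([[0.5], [1.5]])        # [N=2, C=1] query points
    y = interpolate_fn(x, xp, yp)           # expected approx. [[5.0], [25.0]]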

View File

@ -40,8 +40,7 @@ class RandomLatentConverter(nn.Module):
def __init__(self, channels): def __init__(self, channels):
super().__init__() super().__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
nn.Linear(channels, channels)
) )
self.channels = channels self.channels = channels

View File

@ -95,9 +95,7 @@ def _expand_number(m):
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + " hundred"
else:
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
else:
return _inflect.number_to_words(num, andword="")
@ -165,9 +163,7 @@ def lev_distance(s1, s2):
if c1 == c2:
distances_.append(distances[i1])
else:
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
distances = distances_
return distances[-1]
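A quick sanity check of the dynamic-programming recurrence above, using the classic example:

    assert lev_distance("kitten", "sitting") == 3   # substitute k->s, substitute e->i, append g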

View File

@ -36,12 +36,8 @@ def route_args(router, args, depth):
for key in matched_keys: for key in matched_keys:
val = args[key] val = args[key]
for depth, ((f_args, g_args), routes) in enumerate( for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
zip(routed_args, router[key]) new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
):
new_f_args, new_g_args = map(
lambda route: ({key: val} if route else {}), routes
)
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
return routed_args return routed_args
@ -217,12 +213,8 @@ class Transformer(nn.Module):
layers.append( layers.append(
nn.ModuleList( nn.ModuleList(
[ [
LayerScale( LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm) LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
),
LayerScale(
dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
),
] ]
) )
) )

View File

@ -1,5 +1,7 @@
import os

try:
import gdown
except ImportError:
raise ImportError(
"Sorry, gdown is required in order to download the new BigVGAN vocoder.\n"
@ -11,9 +13,7 @@ import progressbar
D_STEM = "https://drive.google.com/uc?id="
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
MODELS = {
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
@ -30,6 +30,8 @@ MODELS = {
} }
pbar = None pbar = None
def download_models(specific_models=None): def download_models(specific_models=None):
""" """
Call to download all the models that Tortoise uses. Call to download all the models that Tortoise uses.
@ -62,6 +64,7 @@ def download_models(specific_models=None):
request.urlretrieve(url, model_path, show_progress) request.urlretrieve(url, model_path, show_progress)
print("Done.") print("Done.")
def get_model_path(model_name, models_dir=MODELS_DIR): def get_model_path(model_name, models_dir=MODELS_DIR):
""" """
Get path to given model, download it if it doesn't exist. Get path to given model, download it if it doesn't exist.
@ -71,4 +74,4 @@ def get_model_path(model_name, models_dir=MODELS_DIR):
model_path = os.path.join(models_dir, model_name)
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
download_models([model_name])
return model_path
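Usage sketch (the path is illustrative): the cache location can be redirected with the TORTOISE_MODELS_DIR environment variable, which is read at import time; get_model_path then downloads on first use and returns the local path:

    import os

    os.environ["TORTOISE_MODELS_DIR"] = "/data/tts-models"   # hypothetical path; set before importing this module

    path = get_model_path("autoregressive.pth")              # downloads into the cache if missing, then returns the path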

View File

@ -1,12 +1,12 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import json
from enum import Enum
from typing import Optional, Callable
from dataclasses import dataclass
MAX_WAV_VALUE = 32768.0 MAX_WAV_VALUE = 32768.0
@ -40,18 +40,12 @@ class KernelPredictor(torch.nn.Module):
self.conv_kernel_size = conv_kernel_size self.conv_kernel_size = conv_kernel_size
self.conv_layers = conv_layers self.conv_layers = conv_layers
kpnet_kernel_channels = ( kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
) # l_w
kpnet_bias_channels = conv_out_channels * conv_layers # l_b kpnet_bias_channels = conv_out_channels * conv_layers # l_b
self.input_conv = nn.Sequential( self.input_conv = nn.Sequential(
nn.utils.weight_norm( nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
),
getattr(nn, kpnet_nonlinear_activation)(
**kpnet_nonlinear_activation_params
),
) )
self.residual_convs = nn.ModuleList() self.residual_convs = nn.ModuleList()
@ -69,9 +63,7 @@ class KernelPredictor(torch.nn.Module):
bias=True, bias=True,
) )
), ),
getattr(nn, kpnet_nonlinear_activation)( getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
**kpnet_nonlinear_activation_params
),
nn.utils.weight_norm( nn.utils.weight_norm(
nn.Conv1d( nn.Conv1d(
kpnet_hidden_channels, kpnet_hidden_channels,
@ -81,9 +73,7 @@ class KernelPredictor(torch.nn.Module):
bias=True, bias=True,
) )
), ),
getattr(nn, kpnet_nonlinear_activation)( getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
**kpnet_nonlinear_activation_params
),
) )
) )
self.kernel_conv = nn.utils.weight_norm( self.kernel_conv = nn.utils.weight_norm(
@ -252,17 +242,11 @@ class LVCBlock(torch.nn.Module):
""" """
batch, _, in_length = x.shape batch, _, in_length = x.shape
batch, _, out_channels, kernel_size, kernel_length = kernel.shape batch, _, out_channels, kernel_size, kernel_length = kernel.shape
assert in_length == ( assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
kernel_length * hop_size
), "length of (x, kernel) is not matched"
padding = dilation * int((kernel_size - 1) / 2) padding = dilation * int((kernel_size - 1) / 2)
x = F.pad( x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
x, (padding, padding), "constant", 0 x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
) # (batch, in_channels, in_length + 2*padding)
x = x.unfold(
2, hop_size + 2 * padding, hop_size
) # (batch, in_channels, kernel_length, hop_size + 2*padding)
if hop_size < dilation: if hop_size < dilation:
x = F.pad(x, (0, dilation), "constant", 0) x = F.pad(x, (0, dilation), "constant", 0)
@ -270,12 +254,8 @@ class LVCBlock(torch.nn.Module):
3, dilation, dilation 3, dilation, dilation
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
x = x[:, :, :, :, :hop_size] x = x[:, :, :, :, :hop_size]
x = x.transpose( x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
3, 4 x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
x = x.unfold(
4, kernel_size, 1
) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
o = torch.einsum("bildsk,biokl->bolsd", x, kernel) o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
o = o.to(memory_format=torch.channels_last_3d) o = o.to(memory_format=torch.channels_last_3d)
@ -334,15 +314,11 @@ class UnivNetGenerator(nn.Module):
) )
) )
self.conv_pre = nn.utils.weight_norm( self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
)
self.conv_post = nn.Sequential( self.conv_post = nn.Sequential(
nn.LeakyReLU(lReLU_slope), nn.LeakyReLU(lReLU_slope),
nn.utils.weight_norm( nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
),
nn.Tanh(), nn.Tanh(),
) )
@ -399,12 +375,16 @@ class VocType:
constructor: Callable[[], nn.Module] constructor: Callable[[], nn.Module]
model_path: str model_path: str
subkey: Optional[str] = None subkey: Optional[str] = None
def optionally_index(self, model_dict): def optionally_index(self, model_dict):
if self.subkey is not None: if self.subkey is not None:
return model_dict[self.subkey] return model_dict[self.subkey]
return model_dict return model_dict
class VocConf(Enum): class VocConf(Enum):
Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g') Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g")
if __name__ == "__main__": if __name__ == "__main__":
model = UnivNetGenerator() model = UnivNetGenerator()

View File

@ -12,9 +12,7 @@ def max_alignment(s1, s2, skip_character="~", record=None):
""" """
if record is None: if record is None:
record = {} record = {}
assert ( assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
skip_character not in s1
), f"Found the skip character {skip_character} in the provided string, {s1}"
if len(s1) == 0: if len(s1) == 0:
return "" return ""
if len(s2) == 0: if len(s2) == 0:
@ -49,15 +47,9 @@ class Wav2VecAlignment:
""" """
def __init__(self, device="cuda"): def __init__(self, device="cuda"):
self.model = Wav2Vec2ForCTC.from_pretrained( self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
"jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli" self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
).cpu() self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols")
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"facebook/wav2vec2-large-960h"
)
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
"jbetker/tacotron-symbols"
)
self.device = device self.device = device
def align(self, audio, expected_text, audio_sample_rate=24000): def align(self, audio, expected_text, audio_sample_rate=24000):
@ -117,9 +109,7 @@ class Wav2VecAlignment:
) )
# Now fix up alignments. Anything with -1 should be interpolated. # Now fix up alignments. Anything with -1 should be interpolated.
alignments.append( alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
orig_len
) # This'll get removed but makes the algorithm below more readable.
for i in range(len(alignments)): for i in range(len(alignments)):
if alignments[i] == -1: if alignments[i] == -1:
for j in range(i + 1, len(alignments)): for j in range(i + 1, len(alignments)):
@ -128,9 +118,7 @@ class Wav2VecAlignment:
break break
for j in range(i, next_found_token): for j in range(i, next_found_token):
gap = alignments[next_found_token] - alignments[i - 1] gap = alignments[next_found_token] - alignments[i - 1]
alignments[j] = (j - i + 1) * gap // ( alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
next_found_token - i + 1
) + alignments[i - 1]
return alignments[:-1] return alignments[:-1]
@ -140,9 +128,7 @@ class Wav2VecAlignment:
splitted = expected_text.split("[") splitted = expected_text.split("[")
fully_split = [splitted[0]] fully_split = [splitted[0]]
for spl in splitted[1:]: for spl in splitted[1:]:
assert ( assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
"]" in spl
), 'Every "[" character must be paired with a "]" with no nesting.'
fully_split.extend(spl.split("]")) fully_split.extend(spl.split("]"))
# At this point, fully_split is a list of strings, with every other string being something that should be redacted. # At this point, fully_split is a list of strings, with every other string being something that should be redacted.

File diff suppressed because it is too large

View File

@ -1,31 +0,0 @@
from torch import nn
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock
class FramePriorNet(nn.Module):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, num_res_blocks=13, num_conv_blocks=2
):
super().__init__()
self.res_blocks = nn.ModuleList()
for idx in range(num_res_blocks):
block = Conv1dBNBlock(
in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == num_res_blocks else hidden_channels,
hidden_channels,
kernel_size,
1,
num_conv_blocks,
)
self.res_blocks.append(block)
def forward(self, x, x_mask=None):
if x_mask is None:
x_mask = 1.0
o = x * x_mask
for block in self.res_blocks:
res = o
o = block(o)
o = o + res
if x_mask is not None:
o = o * x_mask
return o

View File

@ -1,91 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, filter):
super().__init__()
self.num_mel = num_mel
start_index = 2
end_index = filter / 16
i = start_index
filt_len = []
while i <= end_index:
i = i * 2
filt_len.append(i)
filt_len.append(i)
filters = [1] + filt_len
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
)
def forward(self, inputs, input_lengths):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
valid_lengths = input_lengths.float() # [batch_size]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
# Create the post conv width mask based on the valid lengths of the output of the convolution.
# The valid lengths for the output of a convolution on varying length inputs is
# ceil(input_length/stride) + 1 for stride=3 and padding=2
# For example (kernel_size=3, stride=2, padding=2):
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
# _____
# x _____
# x _____
# x ____
# x
# x x x x -> Output valid length = 4
# Since every example in te batch is zero padded and therefore have separate valid_lengths,
# we need to mask off all the values AFTER the valid length for each example in the batch.
# Otherwise, the convolutions create noise and a lot of not real information
valid_lengths = (valid_lengths / 2).float()
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
post_conv_max_width = x.size(2)
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
len(valid_lengths), post_conv_max_width
) < valid_lengths.unsqueeze(1)
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
x = x * mask
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
post_conv_input_lengths = valid_lengths
packed_seqs = nn.utils.rnn.pack_padded_sequence(
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.recurrence.flatten_parameters()
_, (ht, _) = self.recurrence(packed_seqs)
last_output = ht[-1]
return last_output.to(inputs.device) # [B, 128]

View File

@ -1,67 +0,0 @@
import torch
from torch.autograd import Function
class VectorQuantization(Function):
@staticmethod
def forward(ctx, inputs, codebook):
with torch.no_grad():
embedding_size = codebook.size(1)
inputs_size = inputs.size()
inputs_flatten = inputs.view(-1, embedding_size)
codebook_sqr = torch.sum(codebook ** 2, dim=1)
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
# Compute the distances to the codebook
distances = torch.addmm(codebook_sqr + inputs_sqr,
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
_, indices_flatten = torch.min(distances, dim=1)
indices = indices_flatten.view(*inputs_size[:-1])
ctx.mark_non_differentiable(indices)
return indices
@staticmethod
def backward(ctx, grad_output):
raise RuntimeError('Trying to call `.grad()` on graph containing '
'`VectorQuantization`. The function `VectorQuantization` '
'is not differentiable. Use `VectorQuantizationStraightThrough` '
'if you want a straight-through estimator of the gradient.')
class VectorQuantizationStraightThrough(Function):
@staticmethod
def forward(ctx, inputs, codebook):
indices = vq(inputs, codebook)
indices_flatten = indices.view(-1)
ctx.save_for_backward(indices_flatten, codebook)
ctx.mark_non_differentiable(indices_flatten)
codes_flatten = torch.index_select(codebook, dim=0,
index=indices_flatten)
codes = codes_flatten.view_as(inputs)
return (codes, indices_flatten)
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
# Straight-through estimator
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
embedding_size = codebook.size(1)
grad_output_flatten = (grad_output.contiguous()
.view(-1, embedding_size))
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output_flatten)
return (grad_inputs, grad_codebook)
vq = VectorQuantization.apply
vq_st = VectorQuantizationStraightThrough.apply
__all__ = [vq, vq_st]

View File

@ -1,305 +0,0 @@
import copy
from abc import abstractmethod
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
from torch import nn
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
class BaseTacotron(BaseTTS):
"""Base class shared by Tacotron and Tacotron2"""
def __init__(
self,
config: "TacotronConfig",
ap: "AudioProcessor",
tokenizer: "TTSTokenizer",
speaker_manager: SpeakerManager = None,
):
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields as class attributes
for key in config:
setattr(self, key, config[key])
# layers
self.embedding = None
self.encoder = None
self.decoder = None
self.postnet = None
# init tensors
self.embedded_speakers = None
self.embedded_speakers_projected = None
# global style token
if self.gst and self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# Capacitron
if self.capacitron_vae and self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
self.capacitron_vae_layer = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
@staticmethod
def _format_aux_input(aux_input: Dict) -> Dict:
"""Set missing fields to their default values"""
if aux_input:
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
return None
#############################
# INIT FUNCTIONS
#############################
def _init_backward_decoder(self):
"""Init the backward decoder for Forward-Backward decoding."""
self.decoder_backward = copy.deepcopy(self.decoder)
def _init_coarse_decoder(self):
"""Init the coarse decoder for Double-Decoder Consistency."""
self.coarse_decoder = copy.deepcopy(self.decoder)
self.coarse_decoder.r_init = self.ddc_r
self.coarse_decoder.set_r(self.ddc_r)
#############################
# CORE FUNCTIONS
#############################
@abstractmethod
def forward(self):
pass
@abstractmethod
def inference(self):
pass
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals.
Args:
config (Coqpit): model configuration.
checkpoint_path (str): path to checkpoint file.
eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config
if "r" in state:
# set r from the state (for compatibility with older checkpoints)
self.decoder.set_r(state["r"])
elif "config" in state:
# set r from config used at training time (for inference)
self.decoder.set_r(state["config"]["r"])
else:
# set r from the new config (for new models)
self.decoder.set_r(config.r)
if eval:
self.eval()
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
assert not self.training
def get_criterion(self) -> nn.Module:
"""Get the model criterion used in training."""
return TacotronLoss(self.config)
@staticmethod
def init_from_config(config: Coqpit):
"""Initialize model from config."""
from TTS.utils.audio import AudioProcessor
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
##########################
# TEST AND LOG FUNCTIONS #
##########################
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return {"figures": test_figures, "audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
#############################
# COMMON COMPUTE FUNCTIONS
#############################
def compute_masks(self, text_lengths, mel_lengths):
"""Compute masks against sequence paddings."""
# B x T_in_max (boolean)
input_mask = sequence_mask(text_lengths)
output_mask = None
if mel_lengths is not None:
max_len = mel_lengths.max()
r = self.decoder.r
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
output_mask = sequence_mask(mel_lengths, max_len=max_len)
return input_mask, output_mask
def _backward_pass(self, mel_specs, encoder_outputs, mask):
"""Run backwards decoder"""
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
return decoder_outputs_b, alignments_b
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
"""Double Decoder Consistency"""
T = mel_specs.shape[1]
if T % self.coarse_decoder.r > 0:
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask
)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode="nearest",
).transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
return decoder_outputs_backward, alignments_backward
#############################
# EMBEDDING FUNCTIONS
#############################
def compute_gst(self, inputs, style_input, speaker_embedding=None):
"""Compute global style token"""
if isinstance(style_input, dict):
# multiply each style token with a weight
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
if speaker_embedding is not None:
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
for k_token, v_amplifier in style_input.items():
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
elif style_input is None:
# ignore style token and return zero tensor
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
else:
# compute style tokens
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(
VAE_outputs,
posterior_distribution,
prior_distribution,
capacitron_beta,
) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable
)
VAE_outputs = VAE_outputs.to(inputs.device)
encoder_output = self._concat_speaker_embedding(
inputs, VAE_outputs
) # concatenate to the output of the basic tacotron encoder
return (
encoder_output,
posterior_distribution,
prior_distribution,
capacitron_beta,
)
@staticmethod
def _add_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = outputs + embedded_speakers_
return outputs
@staticmethod
def _concat_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
return outputs
#############################
# CALLBACKS
#############################
def on_epoch_start(self, trainer):
"""Callback for setting values wrt gradual training schedule.
Args:
trainer (TrainerTTS): TTS trainer object that is used to train this model.
"""
if self.gradual_training:
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
trainer.config.r = r
self.decoder.set_r(r)
if trainer.config.bidirectional_decoder:
trainer.model.decoder_backward.set_r(r)
print(f"\n > Number of output frames: {self.decoder.r}")

View File

@ -2,12 +2,12 @@
import os import os
import random import random
from contextlib import contextmanager
from time import time from time import time
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torchaudio import torchaudio
from tqdm import tqdm from tqdm import tqdm
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
@ -16,22 +16,14 @@ from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP from TTS.tts.layers.tortoise.clvp import CLVP
from TTS.tts.layers.tortoise.cvvp import CVVP from TTS.tts.layers.tortoise.cvvp import CVVP
from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.diffusion import (
SpacedDiffusion,
get_named_beta_schedule,
space_timesteps,
)
from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from contextlib import contextmanager
def pad_or_truncate(t, length): def pad_or_truncate(t, length):
""" """
@ -56,9 +48,7 @@ def load_discrete_vocoder_diffuser(
Helper function to load a GaussianDiffusion instance configured for use as a vocoder. Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
""" """
return SpacedDiffusion( return SpacedDiffusion(
use_timesteps=space_timesteps( use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
trained_diffusion_steps, [desired_diffusion_steps]
),
model_mean_type="epsilon", model_mean_type="epsilon",
model_var_type="learned_range", model_var_type="learned_range",
loss_type="mse", loss_type="mse",
@ -137,12 +127,12 @@ def do_spectrogram_diffusion(
noise = torch.randn(output_shape, device=latents.device) * temperature noise = torch.randn(output_shape, device=latents.device) * temperature
mel = diffuser.sample_loop( mel = diffuser.sample_loop(
diffusion_model, diffusion_model,
output_shape, output_shape,
noise=noise, noise=noise,
model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
progress=verbose progress=verbose,
) )
return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
@ -166,9 +156,7 @@ def classify_audio_clip(clip):
kernel_size=5, kernel_size=5,
distribute_zero_label=False, distribute_zero_label=False,
) )
classifier.load_state_dict( classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
)
clip = clip.cpu().unsqueeze(0) clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1) results = F.softmax(classifier(clip), dim=-1)
return results[0][0] return results[0][0]
@ -238,9 +226,7 @@ class TextToSpeech:
self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed
self.models_dir = models_dir self.models_dir = models_dir
self.autoregressive_batch_size = ( self.autoregressive_batch_size = (
pick_best_batch_size_for_gpu() pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
if autoregressive_batch_size is None
else autoregressive_batch_size
) )
self.enable_redaction = enable_redaction self.enable_redaction = enable_redaction
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -274,9 +260,7 @@ class TextToSpeech:
self.autoregressive.load_state_dict(torch.load(ar_path)) self.autoregressive.load_state_dict(torch.load(ar_path))
self.autoregressive.post_init_gpt2_config(kv_cache) self.autoregressive.post_init_gpt2_config(kv_cache)
diff_path = diff_checkpoint or get_model_path( diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
"diffusion_decoder.pth", models_dir
)
self.diffusion = ( self.diffusion = (
DiffusionTts( DiffusionTts(
model_channels=1024, model_channels=1024,
@ -365,9 +349,7 @@ class TextToSpeech:
.cpu() .cpu()
.eval() .eval()
) )
self.cvvp.load_state_dict( self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))
torch.load(get_model_path("cvvp.pth", self.models_dir))
)
def get_conditioning_latents( def get_conditioning_latents(
self, self,
@ -407,11 +389,7 @@ class TextToSpeech:
DURS_CONST = 102400 DURS_CONST = 102400
for ls in voice_samples: for ls in voice_samples:
# The diffuser operates at a sample rate of 24000 (except for the latent inputs) # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
sample = ( sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
torchaudio.functional.resample(ls[0], 22050, 24000)
if original_tortoise
else ls[1]
)
if latent_averaging_mode == 0: if latent_averaging_mode == 0:
sample = pad_or_truncate(sample, DURS_CONST) sample = pad_or_truncate(sample, DURS_CONST)
cond_mel = wav_to_univnet_mel( cond_mel = wav_to_univnet_mel(
@ -426,9 +404,7 @@ class TextToSpeech:
if latent_averaging_mode == 2: if latent_averaging_mode == 2:
temp_diffusion_conds = [] temp_diffusion_conds = []
for chunk in range(ceil(sample.shape[1] / DURS_CONST)): for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
current_sample = sample[ current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST]
:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST
]
current_sample = pad_or_truncate(current_sample, DURS_CONST) current_sample = pad_or_truncate(current_sample, DURS_CONST)
cond_mel = wav_to_univnet_mel( cond_mel = wav_to_univnet_mel(
current_sample.to(self.device), current_sample.to(self.device),
@ -440,9 +416,7 @@ class TextToSpeech:
elif latent_averaging_mode == 2: elif latent_averaging_mode == 2:
temp_diffusion_conds.append(cond_mel) temp_diffusion_conds.append(cond_mel)
if latent_averaging_mode == 2: if latent_averaging_mode == 2:
diffusion_conds.append( diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0))
torch.stack(temp_diffusion_conds).mean(0)
)
diffusion_conds = torch.stack(diffusion_conds, dim=1) diffusion_conds = torch.stack(diffusion_conds, dim=1)
with self.temporary_cuda(self.diffusion) as diffusion: with self.temporary_cuda(self.diffusion) as diffusion:
@ -471,9 +445,7 @@ class TextToSpeech:
) )
) )
with torch.no_grad(): with torch.no_grad():
return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion( return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))
torch.tensor([0.0])
)
def tts_with_preset(self, text, preset="fast", **kwargs): def tts_with_preset(self, text, preset="fast", **kwargs):
""" """
@ -521,10 +493,7 @@ class TextToSpeech:
"diffusion_iterations": 50, "diffusion_iterations": 50,
"sampler": "ddim", "sampler": "ddim",
}, },
"fast_old": { "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
"num_autoregressive_samples": 96,
"diffusion_iterations": 80
},
"standard": { "standard": {
"num_autoregressive_samples": 256, "num_autoregressive_samples": 256,
"diffusion_iterations": 200, "diffusion_iterations": 200,
@ -618,9 +587,7 @@ class TextToSpeech:
""" """
deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)
text_tokens = ( text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
)
text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary.
assert ( assert (
text_tokens.shape[-1] < 400 text_tokens.shape[-1] < 400
@ -628,12 +595,7 @@ class TextToSpeech:
auto_conds = None auto_conds = None
if voice_samples is not None: if voice_samples is not None:
( (auto_conditioning, diffusion_conditioning, auto_conds, _,) = self.get_conditioning_latents(
auto_conditioning,
diffusion_conditioning,
auto_conds,
_,
) = self.get_conditioning_latents(
voice_samples, voice_samples,
return_mels=True, return_mels=True,
latent_averaging_mode=latent_averaging_mode, latent_averaging_mode=latent_averaging_mode,
@ -650,10 +612,7 @@ class TextToSpeech:
diffusion_conditioning = diffusion_conditioning.to(self.device) diffusion_conditioning = diffusion_conditioning.to(self.device)
diffuser = load_discrete_vocoder_diffuser( diffuser = load_discrete_vocoder_diffuser(
desired_diffusion_steps=diffusion_iterations, desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler
cond_free=cond_free,
cond_free_k=cond_free_k,
sampler=sampler
) )
# in the case of single_sample, # in the case of single_sample,
@ -664,13 +623,13 @@ class TextToSpeech:
samples = [] samples = []
num_batches = num_autoregressive_samples // self.autoregressive_batch_size num_batches = num_autoregressive_samples // self.autoregressive_batch_size
stop_mel_token = self.autoregressive.stop_mel_token stop_mel_token = self.autoregressive.stop_mel_token
calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" calm_token = (
83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
)
self.autoregressive = self.autoregressive.to(self.device) self.autoregressive = self.autoregressive.to(self.device)
if verbose: if verbose:
print("Generating autoregressive samples..") print("Generating autoregressive samples..")
with self.temporary_cuda( with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
self.autoregressive
) as autoregressive, torch.autocast(
device_type="cuda", dtype=torch.float16, enabled=half device_type="cuda", dtype=torch.float16, enabled=half
): ):
for b in tqdm(range(num_batches), disable=not verbose): for b in tqdm(range(num_batches), disable=not verbose):
@ -689,9 +648,7 @@ class TextToSpeech:
padding_needed = max_mel_tokens - codes.shape[1] padding_needed = max_mel_tokens - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes) samples.append(codes)
self.autoregressive_batch_size = ( self.autoregressive_batch_size = orig_batch_size # in the case of single_sample
orig_batch_size # in the case of single_sample
)
clip_results = [] clip_results = []
with self.temporary_cuda(self.clvp) as clvp, torch.autocast( with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
@ -729,9 +686,7 @@ class TextToSpeech:
if cvvp_amount == 1: if cvvp_amount == 1:
clip_results.append(cvvp) clip_results.append(cvvp)
else: else:
clip_results.append( clip_results.append(cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount))
cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)
)
else: else:
clip_results.append(clvp_res) clip_results.append(clvp_res)
clip_results = torch.cat(clip_results, dim=0) clip_results = torch.cat(clip_results, dim=0)
@ -744,19 +699,14 @@ class TextToSpeech:
# The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage. # results, but will increase memory usage.
with self.temporary_cuda( with self.temporary_cuda(self.autoregressive) as autoregressive:
self.autoregressive
) as autoregressive:
best_latents = autoregressive( best_latents = autoregressive(
auto_conditioning.repeat(k, 1), auto_conditioning.repeat(k, 1),
text_tokens.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
best_results, best_results,
torch.tensor( torch.tensor(
[ [best_results.shape[-1] * self.autoregressive.mel_length_compression],
best_results.shape[-1]
* self.autoregressive.mel_length_compression
],
device=text_tokens.device, device=text_tokens.device,
), ),
return_latent=True, return_latent=True,
@ -778,9 +728,7 @@ class TextToSpeech:
ctokens += 1 ctokens += 1
else: else:
ctokens = 0 ctokens = 0
if ( if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
ctokens > 8
): # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k] latents = latents[:, :k]
break break
with self.temporary_cuda(self.diffusion) as diffusion: with self.temporary_cuda(self.diffusion) as diffusion:
@ -801,10 +749,7 @@ class TextToSpeech:
return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
return clip return clip
wav_candidates = [ wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
potentially_redact(wav_candidate, text)
for wav_candidate in wav_candidates
]
if len(wav_candidates) > 1: if len(wav_candidates) > 1:
res = wav_candidates res = wav_candidates
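
For orientation, a hypothetical call pattern for the preset plumbing reformatted above; the import path, constructor defaults, and output handling are assumptions for illustration rather than something this diff establishes.

from TTS.tts.layers.tortoise.api import TextToSpeech  # assumed module path

tts = TextToSpeech()                       # loads/downloads the Tortoise checkpoints
wav = tts.tts_with_preset(
    "Hello from the fast preset.",
    preset="fast",                         # other presets shown above include "fast_old" and "standard"
    voice_samples=None,                    # no reference clips -> random conditioning latents
    k=1,                                   # keep only the single best candidate
)
print(type(wav))                           # a waveform tensor (a list is returned when k > 1)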

View File

@ -97,7 +97,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_norm = mel_norm self.mel_norm = mel_norm
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None self.mel_basis = None
self.normalized=normalized self.normalized = normalized
if use_mel: if use_mel:
self._build_mel_basis() self._build_mel_basis()
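
The hunk above touches TorchSTFT's non-trainable window parameter and its lazily built mel basis. A rough, self-contained sketch of the equivalent computation follows; the FFT/hop/mel settings and the librosa call are illustrative assumptions, not the class's actual defaults.

import librosa
import torch

n_fft, hop_length, win_length, sample_rate, n_mels = 1024, 256, 1024, 22050, 80
window = torch.hann_window(win_length)        # kept non-trainable in the module above

wav = torch.randn(1, sample_rate)             # one second of placeholder audio
spec = torch.stft(
    wav,
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    window=window,
    return_complex=True,
)
magnitude = spec.abs()                        # [1, n_fft // 2 + 1, frames]

# Lazily built mel basis, in the spirit of _build_mel_basis above.
mel_basis = torch.from_numpy(librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mels)).float()
mel = torch.matmul(mel_basis, magnitude)      # [1, n_mels, frames]
print(mel.shape)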