mirror of https://github.com/coqui-ai/TTS.git

style fix

This commit is contained in:
parent b892aa925a
commit 18c745ceef
@@ -1,9 +1,8 @@
import os
import functools
import math
import os

import torch

import torch.nn as nn
import torch.nn.functional as F
import torchaudio
@@ -11,6 +10,7 @@ from transformers import LogitsWarper

from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
@@ -64,11 +64,11 @@ class QKVAttentionLegacy(nn.Module):
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(
                bs * self.n_heads, weight.shape[-2], weight.shape[-1]
            )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
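
For context on the "more stable with f16" comment above: scaling q and k separately by ch**-0.25 before the einsum is algebraically the same as dividing the attention scores by sqrt(ch) afterwards, but keeps the half-precision intermediates smaller. A minimal sketch, not part of the commit (tensor shapes are illustrative):

import math
import torch

bs_heads, ch, length = 4, 64, 100
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)

scale = 1 / math.sqrt(math.sqrt(ch))
# scaling both operands by ch**-0.25 ...
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
# ... matches dividing the raw scores by sqrt(ch) afterwards
reference = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)
assert torch.allclose(weight, reference, atol=1e-5)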
@@ -112,7 +112,13 @@ class AttentionBlock(nn.Module):

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
            self.relative_pos_embeddings = RelativePositionBias(
                scale=(channels // self.num_heads) ** 0.5,
                causal=False,
                heads=num_heads,
                num_buckets=32,
                max_distance=64,
            )
        else:
            self.relative_pos_embeddings = None
@@ -168,9 +174,7 @@ class Downsample(nn.Module):

        stride = factor
        if use_conv:
            self.op = nn.Conv1d(
                self.channels, self.out_channels, ksize, stride=stride, padding=pad
            )
            self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad)
        else:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool1d(kernel_size=stride, stride=stride)
@@ -182,15 +186,15 @@ class Downsample(nn.Module):

class ResBlock(nn.Module):
    def __init__(
            self,
            channels,
            dropout,
            out_channels=None,
            use_conv=False,
            use_scale_shift_norm=False,
            up=False,
            down=False,
            kernel_size=3,
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        up=False,
        down=False,
        kernel_size=3,
    ):
        super().__init__()
        self.channels = channels
@@ -221,17 +225,13 @@ class ResBlock(nn.Module):
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
            zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv1d(
                channels, self.out_channels, kernel_size, padding=padding
            )
            self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding)
        else:
            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
@@ -249,37 +249,38 @@ class ResBlock(nn.Module):


class AudioMiniEncoder(nn.Module):
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
    def __init__(
        self,
        spec_dim,
        embedding_dim,
        base_channels=128,
        depth=2,
        resnet_blocks=2,
        attn_blocks=4,
        num_attn_heads=4,
        dropout=0,
        downsample_factor=2,
        kernel_size=3,
    ):
        super().__init__()
        self.init = nn.Sequential(
            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
        )
        self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1))
        ch = base_channels
        res = []
        for l in range(depth):
            for r in range(resnet_blocks):
                res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
            res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
            res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
            ch *= 2
        self.res = nn.Sequential(*res)
        self.final = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            nn.Conv1d(ch, embedding_dim, 1)
        )
        self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
        attn = []
        for a in range(attn_blocks):
            attn.append(AttentionBlock(embedding_dim, num_attn_heads,))
            attn.append(
                AttentionBlock(
                    embedding_dim,
                    num_attn_heads,
                )
            )
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim
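
A hedged usage sketch for the encoder defined above, not part of the commit; the sizes are illustrative and the import path is assumed from this module's location in the repo:

import torch
from TTS.tts.layers.tortoise.arch_utils import AudioMiniEncoder  # assumed import path

enc = AudioMiniEncoder(spec_dim=80, embedding_dim=256)
spec = torch.randn(2, 80, 160)  # (batch, spec_dim, frames)
emb = enc(spec)                 # forward() returns h[:, :, 0]
print(emb.shape)                # expected: torch.Size([2, 256])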
@@ -291,15 +292,24 @@ class AudioMiniEncoder(nn.Module):
        return h[:, :, 0]


DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../utils/assets/tortoise/mel_norms.pth')
DEFAULT_MEL_NORM_FILE = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/mel_norms.pth"
)


class TorchMelSpectrogram(nn.Module):
    def __init__(self, filter_length=1024, hop_length=256,
                 win_length=1024, n_mel_channels=80,
                 mel_fmin=0, mel_fmax=8000,
                 sampling_rate=22050, normalize=False,
                 mel_norm_file=DEFAULT_MEL_NORM_FILE):
    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        mel_fmin=0,
        mel_fmax=8000,
        sampling_rate=22050,
        normalize=False,
        mel_norm_file=DEFAULT_MEL_NORM_FILE,
    ):
        super().__init__()
        # These are the default tacotron values for the MEL spectrogram.
        self.filter_length = filter_length
@@ -309,16 +319,18 @@ class TorchMelSpectrogram(nn.Module):
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            power=2,
            normalized=normalize,
            sample_rate=self.sampling_rate,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=self.n_mel_channels,
            norm="slaney")
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            n_fft=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            power=2,
            normalized=normalize,
            sample_rate=self.sampling_rate,
            f_min=self.mel_fmin,
            f_max=self.mel_fmax,
            n_mels=self.n_mel_channels,
            norm="slaney",
        )
        self.mel_norm_file = mel_norm_file
        if self.mel_norm_file is not None:
            self.mel_norms = torch.load(self.mel_norm_file)
@@ -326,7 +338,9 @@ class TorchMelSpectrogram(nn.Module):
            self.mel_norms = None

    def forward(self, inp):
        if len(inp.shape) == 3:  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
        if (
            len(inp.shape) == 3
        ):  # Automatically squeeze out the channels dimension if it is present (assuming mono-audio)
            inp = inp.squeeze(1)
        assert len(inp.shape) == 2
        self.mel_stft = self.mel_stft.to(inp.device)
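
A rough sketch of how this module is driven, not part of the commit; mel_norm_file is disabled here so the example does not depend on the bundled norm file, and the import path is assumed:

import torch
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram  # assumed import path

to_mel = TorchMelSpectrogram(mel_norm_file=None)
wav = torch.randn(1, 1, 22050)  # ~1 s of fake mono audio at the default 22050 Hz
mel = to_mel(wav)               # the channel dim is squeezed out as in forward() above
print(mel.shape)                # expected: (1, 80, n_frames)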
@@ -344,6 +358,7 @@ class CheckpointedLayer(nn.Module):
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
    checkpoint for all other args.
    """

    def __init__(self, wrap):
        super().__init__()
        self.wrap = wrap
@@ -360,6 +375,7 @@ class CheckpointedXTransformerEncoder(nn.Module):
    Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid
    to channels-last that XTransformer expects.
    """

    def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs):
        super().__init__()
        self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs)
@@ -374,10 +390,10 @@ class CheckpointedXTransformerEncoder(nn.Module):

    def forward(self, x, **kwargs):
        if self.needs_permute:
            x = x.permute(0,2,1)
            x = x.permute(0, 2, 1)
        h = self.transformer(x, **kwargs)
        if self.exit_permute:
            h = h.permute(0,2,1)
            h = h.permute(0, 2, 1)
        return h
@@ -392,9 +408,7 @@ class TypicalLogitsWarper(LogitsWarper):
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # calculate entropy
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
@@ -409,15 +423,11 @@ class TypicalLogitsWarper(LogitsWarper):
        # Remove tokens with cumulative mass above the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind[last_ind < 0] = 0
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(
            1, last_ind.view(-1, 1)
        )
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
        return scores
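
This warper is wired into HuggingFace generation later in this commit (see the autoregressive.py hunk that builds LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])). A minimal standalone sketch of that wiring, not part of the commit, with dummy tensors and an assumed import path:

import torch
from transformers import LogitsProcessorList
from TTS.tts.layers.tortoise.arch_utils import TypicalLogitsWarper  # assumed import path

warper = TypicalLogitsWarper(mass=0.9)
input_ids = torch.randint(0, 100, (2, 5))  # dummy token ids
scores = torch.randn(2, 100)               # dummy next-token logits
filtered = LogitsProcessorList([warper])(input_ids, scores)
print(filtered.shape)                      # unchanged shape; filtered entries are set to -inf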
@@ -7,11 +7,10 @@ import numpy as np
import torch
import torchaudio
from scipy.io.wavfile import read

from TTS.utils.audio.torch_transforms import TorchSTFT

BUILTIN_VOICES_DIR = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices"
)
BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices")


def load_wav_to_torch(full_path):
@@ -58,10 +57,7 @@ def read_audio_file(audiopath: str):
def load_required_audio(audiopath: str):
    audio, lsr = read_audio_file(audiopath)

    audios = [
        torchaudio.functional.resample(audio, lsr, sampling_rate)
        for sampling_rate in (22050, 24000)
    ]
    audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)]
    for audio in audios:
        check_audio(audio, audiopath)
@@ -83,9 +79,7 @@ TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    return ((norm_mel + 1) / 2) * (
        TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    ) + TACOTRON_MEL_MIN
    return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
@@ -118,11 +112,7 @@ def get_voices(extra_voice_dirs: List[str] = []):
        for sub in subs:
            subj = os.path.join(d, sub)
            if os.path.isdir(subj):
                voices[sub] = (
                    list(glob(f"{subj}/*.wav"))
                    + list(glob(f"{subj}/*.mp3"))
                    + list(glob(f"{subj}/*.pth"))
                )
                voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth"))
    return voices
@@ -148,9 +138,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
    for voice in voices:
        if voice == "random":
            if len(voices) > 1:
                print(
                    "Cannot combine a random voice with a non-random voice. Just using a random voice."
                )
                print("Cannot combine a random voice with a non-random voice. Just using a random voice.")
            return None, None
        clip, latent = load_voice(voice, extra_voice_dirs)
        if latent is None:
@@ -171,18 +159,21 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []):
        latents = (latents_0, latents_1)
        return None, latents


def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"):
    stft = TorchSTFT(n_fft=1024,
                     hop_length=256,
                     win_length=1024,
                     use_mel=True,
                     n_mels=100,
                     sample_rate=24000,
                     mel_fmin=0,
                     mel_fmax=12000)
    stft = TorchSTFT(
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        use_mel=True,
        n_mels=100,
        sample_rate=24000,
        mel_fmin=0,
        mel_fmax=12000,
    )
    stft = stft.to(device)
    mel = stft(wav)
    mel = dynamic_range_compression(mel)
    if do_normalization:
        mel = normalize_tacotron_mel(mel)
    return mel
    return mel
@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
|||
|
||||
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper
|
||||
|
||||
|
||||
def null_position_embeddings(range, dim):
|
||||
return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
|
||||
|
||||
|
@ -98,9 +99,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
assert self.cached_mel_emb is not None
|
||||
assert inputs_embeds is None # Not supported by this inference model.
|
||||
assert labels is None # Training not supported by this inference model.
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# Create embedding
|
||||
mel_len = self.cached_mel_emb.shape[1]
|
||||
|
@ -109,9 +108,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
text_emb = self.embeddings(text_inputs)
|
||||
text_emb = text_emb + self.text_pos_embedding(text_emb)
|
||||
if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
|
||||
mel_emb = self.cached_mel_emb.repeat_interleave(
|
||||
text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
|
||||
)
|
||||
mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0)
|
||||
else: # this outcome only occurs once per loop in most cases
|
||||
mel_emb = self.cached_mel_emb
|
||||
emb = torch.cat([mel_emb, text_emb], dim=1)
|
||||
|
@ -158,10 +155,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
|
|||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(
|
||||
past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past
|
||||
)
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
@ -210,9 +204,7 @@ class LearnedPositionEmbeddings(nn.Module):
|
|||
return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]
|
||||
|
||||
|
||||
def build_hf_gpt_transformer(
|
||||
layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing
|
||||
):
|
||||
def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing):
|
||||
"""
|
||||
GPT-2 implemented by the HuggingFace library.
|
||||
"""
|
||||
|
@ -230,9 +222,7 @@ def build_hf_gpt_transformer(
|
|||
)
|
||||
gpt = GPT2Model(gpt_config)
|
||||
# Override the built in positional embeddings
|
||||
del (
|
||||
gpt.wpe
|
||||
) # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
|
||||
del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024)
|
||||
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
|
||||
# Built-in token embeddings are unused.
|
||||
del gpt.wte
|
||||
|
@ -251,21 +241,15 @@ class MelEncoder(nn.Module):
|
|||
self.channels = channels
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]
|
||||
),
|
||||
nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 16, channels // 2),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]
|
||||
),
|
||||
nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
|
||||
nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
|
||||
nn.GroupNorm(channels // 8, channels),
|
||||
nn.ReLU(),
|
||||
nn.Sequential(
|
||||
*[ResBlock(channels) for _ in range(resblocks_per_reduction)]
|
||||
),
|
||||
nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
|
||||
)
|
||||
self.reduction = 4
|
||||
|
||||
|
@ -317,9 +301,7 @@ class UnifiedVoice(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.number_text_tokens = number_text_tokens
|
||||
self.start_text_token = (
|
||||
number_text_tokens * types if start_text_token is None else start_text_token
|
||||
)
|
||||
self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token
|
||||
self.stop_text_token = 0
|
||||
self.number_mel_codes = number_mel_codes
|
||||
self.start_mel_token = start_mel_token
|
||||
|
@ -331,12 +313,8 @@ class UnifiedVoice(nn.Module):
|
|||
self.model_dim = model_dim
|
||||
self.max_conditioning_inputs = max_conditioning_inputs
|
||||
self.mel_length_compression = mel_length_compression
|
||||
self.conditioning_encoder = ConditioningEncoder(
|
||||
80, model_dim, num_attn_heads=heads
|
||||
)
|
||||
self.text_embedding = nn.Embedding(
|
||||
self.number_text_tokens * types + 1, model_dim
|
||||
)
|
||||
self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
|
||||
self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
|
||||
if use_mel_codes_as_input:
|
||||
self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
|
||||
else:
|
||||
|
@ -356,12 +334,8 @@ class UnifiedVoice(nn.Module):
|
|||
checkpointing,
|
||||
)
|
||||
if train_solo_embeddings:
|
||||
self.mel_solo_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
|
||||
)
|
||||
self.text_solo_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, model_dim) * 0.02, requires_grad=True
|
||||
)
|
||||
self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
|
||||
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
|
||||
else:
|
||||
self.mel_solo_embedding = 0
|
||||
self.text_solo_embedding = 0
|
||||
|
@ -414,9 +388,7 @@ class UnifiedVoice(nn.Module):
|
|||
preformatting to create a working TTS model.
|
||||
"""
|
||||
# Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
|
||||
mel_lengths = torch.div(
|
||||
wav_lengths, self.mel_length_compression, rounding_mode="trunc"
|
||||
)
|
||||
mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc")
|
||||
for b in range(len(mel_lengths)):
|
||||
actual_end = (
|
||||
mel_lengths[b] + 1
|
||||
|
@ -436,31 +408,22 @@ class UnifiedVoice(nn.Module):
|
|||
return_latent=False,
|
||||
):
|
||||
if second_inputs is not None:
|
||||
emb = torch.cat(
|
||||
[speech_conditioning_inputs, first_inputs, second_inputs], dim=1
|
||||
)
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
|
||||
else:
|
||||
emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
|
||||
|
||||
gpt_out = self.gpt(
|
||||
inputs_embeds=emb, return_dict=True, output_attentions=get_attns
|
||||
)
|
||||
gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
|
||||
if get_attns:
|
||||
return gpt_out.attentions
|
||||
|
||||
enc = gpt_out.last_hidden_state[
|
||||
:, 1:
|
||||
] # The first logit is tied to the speech_conditioning_input
|
||||
enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input
|
||||
enc = self.final_norm(enc)
|
||||
|
||||
if return_latent:
|
||||
return (
|
||||
enc[
|
||||
:,
|
||||
speech_conditioning_inputs.shape[
|
||||
1
|
||||
] : speech_conditioning_inputs.shape[1]
|
||||
+ first_inputs.shape[1],
|
||||
speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1],
|
||||
],
|
||||
enc[:, -second_inputs.shape[1] :],
|
||||
)
|
||||
|
@ -539,9 +502,7 @@ class UnifiedVoice(nn.Module):
|
|||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token
|
||||
)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
|
||||
text_inputs
|
||||
)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
|
||||
mel_codes, self.start_mel_token, self.stop_mel_token
|
||||
)
|
||||
|
@ -596,15 +557,13 @@ class UnifiedVoice(nn.Module):
|
|||
max_generate_length=None,
|
||||
typical_sampling=False,
|
||||
typical_mass=0.9,
|
||||
**hf_generate_kwargs
|
||||
**hf_generate_kwargs,
|
||||
):
|
||||
text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
|
||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(
|
||||
text_inputs, self.start_text_token, self.stop_text_token
|
||||
)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(
|
||||
text_inputs
|
||||
)
|
||||
text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
|
||||
|
||||
conds = speech_conditioning_latent.unsqueeze(1)
|
||||
emb = torch.cat([conds, text_emb], dim=1)
|
||||
|
@ -628,20 +587,14 @@ class UnifiedVoice(nn.Module):
|
|||
num_return_sequences % input_tokens.shape[0] == 0
|
||||
), "The number of return sequences must be divisible by the number of input sequences"
|
||||
fake_inputs = fake_inputs.repeat(num_return_sequences, 1)
|
||||
input_tokens = input_tokens.repeat(
|
||||
num_return_sequences // input_tokens.shape[0], 1
|
||||
)
|
||||
input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
|
||||
inputs = torch.cat([fake_inputs, input_tokens], dim=1)
|
||||
|
||||
logits_processor = (
|
||||
LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)])
|
||||
if typical_sampling
|
||||
else LogitsProcessorList()
|
||||
LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList()
|
||||
) # TODO disable this
|
||||
max_length = (
|
||||
trunc_index + self.max_mel_tokens - 1
|
||||
if max_generate_length is None
|
||||
else trunc_index + max_generate_length
|
||||
trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length
|
||||
)
|
||||
gen = self.inference_model.generate(
|
||||
inputs,
|
||||
|
@ -651,7 +604,7 @@ class UnifiedVoice(nn.Module):
|
|||
max_length=max_length,
|
||||
logits_processor=logits_processor,
|
||||
num_return_sequences=num_return_sequences,
|
||||
**hf_generate_kwargs
|
||||
**hf_generate_kwargs,
|
||||
)
|
||||
return gen[:, trunc_index:]
|
||||
|
||||
|
|
|
@ -1,13 +1,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from TTS.tts.layers.tortoise.arch_utils import (
|
||||
AttentionBlock,
|
||||
Downsample,
|
||||
Upsample,
|
||||
normalization,
|
||||
zero_module,
|
||||
)
|
||||
from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
@ -54,19 +48,13 @@ class ResBlock(nn.Module):
|
|||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
nn.Conv1d(
|
||||
self.out_channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
),
|
||||
zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = nn.Conv1d(
|
||||
dims, channels, self.out_channels, kernel_size, padding=padding
|
||||
)
|
||||
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding)
|
||||
else:
|
||||
self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
|
||||
|
||||
|
@ -104,24 +92,14 @@ class AudioMiniEncoder(nn.Module):
|
|||
self.layers = depth
|
||||
for l in range(depth):
|
||||
for r in range(resnet_blocks):
|
||||
res.append(
|
||||
ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)
|
||||
)
|
||||
res.append(
|
||||
Downsample(
|
||||
ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor
|
||||
)
|
||||
)
|
||||
res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
|
||||
res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor))
|
||||
ch *= 2
|
||||
self.res = nn.Sequential(*res)
|
||||
self.final = nn.Sequential(
|
||||
normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)
|
||||
)
|
||||
self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1))
|
||||
attn = []
|
||||
for a in range(attn_blocks):
|
||||
attn.append(
|
||||
AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
|
||||
)
|
||||
attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
|
||||
self.attn = nn.Sequential(*attn)
|
||||
self.dim = embedding_dim
|
||||
|
||||
|
|
|
@ -12,9 +12,10 @@ def exists(val):
|
|||
return val is not None
|
||||
|
||||
|
||||
def masked_mean(t, mask, dim = 1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.)
|
||||
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
||||
def masked_mean(t, mask, dim=1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.0)
|
||||
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
||||
|
||||
|
||||
class CLVP(nn.Module):
|
||||
"""
|
||||
|
@ -25,23 +26,23 @@ class CLVP(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim_text=512,
|
||||
dim_speech=512,
|
||||
dim_latent=512,
|
||||
num_text_tokens=256,
|
||||
text_enc_depth=6,
|
||||
text_seq_len=120,
|
||||
text_heads=8,
|
||||
num_speech_tokens=8192,
|
||||
speech_enc_depth=6,
|
||||
speech_heads=8,
|
||||
speech_seq_len=250,
|
||||
text_mask_percentage=0,
|
||||
voice_mask_percentage=0,
|
||||
wav_token_compression=1024,
|
||||
use_xformers=False,
|
||||
self,
|
||||
*,
|
||||
dim_text=512,
|
||||
dim_speech=512,
|
||||
dim_latent=512,
|
||||
num_text_tokens=256,
|
||||
text_enc_depth=6,
|
||||
text_seq_len=120,
|
||||
text_heads=8,
|
||||
num_speech_tokens=8192,
|
||||
speech_enc_depth=6,
|
||||
speech_heads=8,
|
||||
speech_seq_len=250,
|
||||
text_mask_percentage=0,
|
||||
voice_mask_percentage=0,
|
||||
wav_token_compression=1024,
|
||||
use_xformers=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_emb = nn.Embedding(num_text_tokens, dim_text)
|
||||
|
@ -59,13 +60,14 @@ class CLVP(nn.Module):
|
|||
dim=dim_text,
|
||||
depth=text_enc_depth,
|
||||
heads=text_heads,
|
||||
ff_dropout=.1,
|
||||
ff_dropout=0.1,
|
||||
ff_mult=2,
|
||||
attn_dropout=.1,
|
||||
attn_dropout=0.1,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
))
|
||||
),
|
||||
)
|
||||
self.speech_transformer = CheckpointedXTransformerEncoder(
|
||||
needs_permute=False,
|
||||
exit_permute=False,
|
||||
|
@ -74,20 +76,23 @@ class CLVP(nn.Module):
|
|||
dim=dim_speech,
|
||||
depth=speech_enc_depth,
|
||||
heads=speech_heads,
|
||||
ff_dropout=.1,
|
||||
ff_dropout=0.1,
|
||||
ff_mult=2,
|
||||
attn_dropout=.1,
|
||||
attn_dropout=0.1,
|
||||
use_rmsnorm=True,
|
||||
ff_glu=True,
|
||||
rotary_pos_emb=True,
|
||||
))
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth,
|
||||
heads=text_heads)
|
||||
self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech,
|
||||
depth=speech_enc_depth, heads=speech_heads)
|
||||
self.text_transformer = Transformer(
|
||||
causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads
|
||||
)
|
||||
self.speech_transformer = Transformer(
|
||||
causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads
|
||||
)
|
||||
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.temperature = nn.Parameter(torch.tensor(1.0))
|
||||
self.text_mask_percentage = text_mask_percentage
|
||||
self.voice_mask_percentage = voice_mask_percentage
|
||||
self.wav_token_compression = wav_token_compression
|
||||
|
@ -96,12 +101,7 @@ class CLVP(nn.Module):
|
|||
self.text_pos_emb = nn.Embedding(text_seq_len, dim_text)
|
||||
self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text,
|
||||
speech_tokens,
|
||||
return_loss=False
|
||||
):
|
||||
def forward(self, text, speech_tokens, return_loss=False):
|
||||
b, device = text.shape[0], text.device
|
||||
if self.training:
|
||||
text_mask = torch.rand_like(text.float()) > self.text_mask_percentage
|
||||
|
@ -131,25 +131,29 @@ class CLVP(nn.Module):
|
|||
temp = self.temperature.exp()
|
||||
|
||||
if not return_loss:
|
||||
sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp
|
||||
sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp
|
||||
return sim
|
||||
|
||||
sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp
|
||||
sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp
|
||||
labels = torch.arange(b, device=device)
|
||||
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2)
|
||||
clip(torch.randint(0,256,(2,120)),
|
||||
torch.tensor([50,100]),
|
||||
torch.randint(0,8192,(2,250)),
|
||||
torch.tensor([101,102]),
|
||||
return_loss=True)
|
||||
nonloss = clip(torch.randint(0,256,(2,120)),
|
||||
torch.tensor([50,100]),
|
||||
torch.randint(0,8192,(2,250)),
|
||||
torch.tensor([101,102]),
|
||||
return_loss=False)
|
||||
print(nonloss.shape)
|
||||
if __name__ == "__main__":
|
||||
clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2)
|
||||
clip(
|
||||
torch.randint(0, 256, (2, 120)),
|
||||
torch.tensor([50, 100]),
|
||||
torch.randint(0, 8192, (2, 250)),
|
||||
torch.tensor([101, 102]),
|
||||
return_loss=True,
|
||||
)
|
||||
nonloss = clip(
|
||||
torch.randint(0, 256, (2, 120)),
|
||||
torch.tensor([50, 100]),
|
||||
torch.randint(0, 8192, (2, 250)),
|
||||
torch.tensor([101, 102]),
|
||||
return_loss=False,
|
||||
)
|
||||
print(nonloss.shape)
|
||||
|
|
|
@ -17,16 +17,7 @@ def masked_mean(t, mask):
|
|||
|
||||
|
||||
class CollapsingTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dim,
|
||||
output_dims,
|
||||
heads,
|
||||
dropout,
|
||||
depth,
|
||||
mask_percentage=0,
|
||||
**encoder_kwargs
|
||||
):
|
||||
def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs):
|
||||
super().__init__()
|
||||
self.transformer = ContinuousTransformerWrapper(
|
||||
max_seq_len=-1,
|
||||
|
@ -105,9 +96,7 @@ class CVVP(nn.Module):
|
|||
self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False)
|
||||
|
||||
if mel_codes is None:
|
||||
self.speech_emb = nn.Conv1d(
|
||||
mel_channels, model_dim, kernel_size=5, padding=2
|
||||
)
|
||||
self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2)
|
||||
else:
|
||||
self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim)
|
||||
self.speech_transformer = CollapsingTransformer(
|
||||
|
@ -135,9 +124,7 @@ class CVVP(nn.Module):
|
|||
enc_speech = self.speech_transformer(speech_emb)
|
||||
speech_latents = self.to_speech_latent(enc_speech)
|
||||
|
||||
cond_latents, speech_latents = map(
|
||||
lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)
|
||||
)
|
||||
cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents))
|
||||
temp = self.temperature.exp()
|
||||
|
||||
if not return_loss:
|
||||
|
|
|
@ -13,8 +13,8 @@ import math
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch as th
|
||||
from tqdm import tqdm
|
||||
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
|
||||
from tqdm import tqdm
|
||||
|
||||
from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper
|
||||
|
||||
|
@ -38,18 +38,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
|
|||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for th.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ th.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
||||
)
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
|
||||
|
||||
|
||||
def approx_standard_normal_cdf(x):
|
||||
|
@ -112,9 +103,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
|||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(
|
||||
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
||||
)
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
|
@ -149,9 +138,9 @@ class ModelMeanType(enum.Enum):
|
|||
Which type of output the model predicts.
|
||||
"""
|
||||
|
||||
PREVIOUS_X = 'previous_x' # the model predicts x_{t-1}
|
||||
START_X = 'start_x' # the model predicts x_0
|
||||
EPSILON = 'epsilon' # the model predicts epsilon
|
||||
PREVIOUS_X = "previous_x" # the model predicts x_{t-1}
|
||||
START_X = "start_x" # the model predicts x_0
|
||||
EPSILON = "epsilon" # the model predicts epsilon
|
||||
|
||||
|
||||
class ModelVarType(enum.Enum):
|
||||
|
@ -162,17 +151,17 @@ class ModelVarType(enum.Enum):
|
|||
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
||||
"""
|
||||
|
||||
LEARNED = 'learned'
|
||||
FIXED_SMALL = 'fixed_small'
|
||||
FIXED_LARGE = 'fixed_large'
|
||||
LEARNED_RANGE = 'learned_range'
|
||||
LEARNED = "learned"
|
||||
FIXED_SMALL = "fixed_small"
|
||||
FIXED_LARGE = "fixed_large"
|
||||
LEARNED_RANGE = "learned_range"
|
||||
|
||||
|
||||
class LossType(enum.Enum):
|
||||
MSE = 'mse' # use raw MSE loss (and KL when learning variances)
|
||||
RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances)
|
||||
KL = 'kl' # use the variational lower-bound
|
||||
RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB
|
||||
MSE = "mse" # use raw MSE loss (and KL when learning variances)
|
||||
RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances)
|
||||
KL = "kl" # use the variational lower-bound
|
||||
RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB
|
||||
|
||||
def is_vb(self):
|
||||
return self == LossType.KL or self == LossType.RESCALED_KL
|
||||
|
@ -239,22 +228,12 @@ class GaussianDiffusion:
|
|||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = (
|
||||
betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
# log calculation clipped because the posterior variance is 0 at the
|
||||
# beginning of the diffusion chain.
|
||||
self.posterior_log_variance_clipped = np.log(
|
||||
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
||||
)
|
||||
self.posterior_mean_coef1 = (
|
||||
betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
self.posterior_mean_coef2 = (
|
||||
(1.0 - self.alphas_cumprod_prev)
|
||||
* np.sqrt(alphas)
|
||||
/ (1.0 - self.alphas_cumprod)
|
||||
)
|
||||
self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
|
||||
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
"""
|
||||
|
@ -264,13 +243,9 @@ class GaussianDiffusion:
|
|||
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
||||
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
||||
"""
|
||||
mean = (
|
||||
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
)
|
||||
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = _extract_into_tensor(
|
||||
self.log_one_minus_alphas_cumprod, t, x_start.shape
|
||||
)
|
||||
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
|
@ -289,8 +264,7 @@ class GaussianDiffusion:
|
|||
assert noise.shape == x_start.shape
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
* noise
|
||||
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior_mean_variance(self, x_start, x_t, t):
|
||||
|
@ -306,9 +280,7 @@ class GaussianDiffusion:
|
|||
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(
|
||||
self.posterior_log_variance_clipped, t, x_t.shape
|
||||
)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
|
@ -317,9 +289,7 @@ class GaussianDiffusion:
|
|||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(
|
||||
self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
|
||||
):
|
||||
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
|
@ -358,9 +328,7 @@ class GaussianDiffusion:
|
|||
model_log_variance = model_var_values
|
||||
model_variance = th.exp(model_log_variance)
|
||||
else:
|
||||
min_log = _extract_into_tensor(
|
||||
self.posterior_log_variance_clipped, t, x.shape
|
||||
)
|
||||
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
|
@ -398,26 +366,18 @@ class GaussianDiffusion:
|
|||
return x
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
|
||||
)
|
||||
pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
|
||||
model_mean = model_output
|
||||
elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
|
||||
if self.model_mean_type == ModelMeanType.START_X:
|
||||
pred_xstart = process_xstart(model_output)
|
||||
else:
|
||||
pred_xstart = process_xstart(
|
||||
self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
|
||||
)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(
|
||||
x_start=pred_xstart, x_t=x, t=t
|
||||
)
|
||||
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
||||
else:
|
||||
raise NotImplementedError(self.model_mean_type)
|
||||
|
||||
assert (
|
||||
model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
)
|
||||
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
return {
|
||||
"mean": model_mean,
|
||||
"variance": model_variance,
|
||||
|
@ -436,16 +396,12 @@ class GaussianDiffusion:
|
|||
assert x_t.shape == xprev.shape
|
||||
return ( # (xprev - coef2*x_t) / coef1
|
||||
_extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
|
||||
- _extract_into_tensor(
|
||||
self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
|
||||
)
|
||||
* x_t
|
||||
- _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t
|
||||
)
|
||||
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- pred_xstart
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
|
@ -463,9 +419,7 @@ class GaussianDiffusion:
|
|||
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
||||
"""
|
||||
gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
|
||||
new_mean = (
|
||||
p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
||||
)
|
||||
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
||||
return new_mean
|
||||
|
||||
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
||||
|
@ -481,16 +435,13 @@ class GaussianDiffusion:
|
|||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
||||
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
|
||||
x, self._scale_timesteps(t), **model_kwargs
|
||||
)
|
||||
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs)
|
||||
|
||||
out = p_mean_var.copy()
|
||||
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
||||
out["mean"], _, _ = self.q_posterior_mean_variance(
|
||||
x_start=out["pred_xstart"], x_t=x, t=t
|
||||
)
|
||||
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
||||
return out
|
||||
|
||||
def k_diffusion_sample_loop(
|
||||
self,
|
||||
k_sampler,
|
||||
|
@ -512,9 +463,7 @@ class GaussianDiffusion:
|
|||
|
||||
def model_split(*args, **kwargs):
|
||||
model_output = model(*args, **kwargs)
|
||||
model_epsilon, model_var = th.split(
|
||||
model_output, model_output.shape[1] // 2, dim=1
|
||||
)
|
||||
model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1)
|
||||
return model_epsilon, model_var
|
||||
|
||||
#
|
||||
|
@ -523,9 +472,7 @@ class GaussianDiffusion:
|
|||
print(th.tensor(self.betas))
|
||||
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas))
|
||||
"""
|
||||
noise_schedule = NoiseScheduleVP(
|
||||
schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4
|
||||
)
|
||||
noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4)
|
||||
|
||||
def model_fn_prewrap(x, t, *args, **kwargs):
|
||||
"""
|
||||
|
@ -584,11 +531,10 @@ class GaussianDiffusion:
|
|||
if self.conditioning_free is not True:
|
||||
raise RuntimeError("cond_free must be true")
|
||||
with tqdm(total=self.num_timesteps) as pbar:
|
||||
return self.k_diffusion_sample_loop(
|
||||
K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs
|
||||
)
|
||||
return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
|
||||
else:
|
||||
raise RuntimeError("sampler not impl")
|
||||
|
||||
def p_sample(
|
||||
self,
|
||||
model,
|
||||
|
@ -625,13 +571,9 @@ class GaussianDiffusion:
|
|||
model_kwargs=model_kwargs,
|
||||
)
|
||||
noise = th.randn_like(x)
|
||||
nonzero_mask = (
|
||||
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
||||
) # no noise when t == 0
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
||||
if cond_fn is not None:
|
||||
out["mean"] = self.condition_mean(
|
||||
cond_fn, out, x, t, model_kwargs=model_kwargs
|
||||
)
|
||||
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
||||
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
|
@ -758,20 +700,11 @@ class GaussianDiffusion:
|
|||
|
||||
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
||||
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
||||
sigma = (
|
||||
eta
|
||||
* th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
||||
* th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
)
|
||||
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
# Equation 12.
|
||||
noise = th.randn_like(x)
|
||||
mean_pred = (
|
||||
out["pred_xstart"] * th.sqrt(alpha_bar_prev)
|
||||
+ th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
||||
)
|
||||
nonzero_mask = (
|
||||
(t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
|
||||
) # no noise when t == 0
|
||||
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
||||
sample = mean_pred + nonzero_mask * sigma * noise
|
||||
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
|
@ -800,16 +733,12 @@ class GaussianDiffusion:
|
|||
# Usually our model outputs epsilon, but we re-derive it
|
||||
# in case we used x_start or x_prev prediction.
|
||||
eps = (
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
|
||||
- out["pred_xstart"]
|
||||
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
|
||||
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
||||
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
||||
|
||||
# Equation 12. reversed
|
||||
mean_pred = (
|
||||
out["pred_xstart"] * th.sqrt(alpha_bar_next)
|
||||
+ th.sqrt(1 - alpha_bar_next) * eps
|
||||
)
|
||||
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
||||
|
||||
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
||||
|
||||
|
@ -897,9 +826,7 @@ class GaussianDiffusion:
|
|||
yield out
|
||||
img = out["sample"]
|
||||
|
||||
def _vb_terms_bpd(
|
||||
self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
|
||||
):
|
||||
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
|
||||
"""
|
||||
Get a term for the variational lower-bound.
|
||||
|
||||
|
@ -910,15 +837,9 @@ class GaussianDiffusion:
|
|||
- 'output': a shape [N] tensor of NLLs or KLs.
|
||||
- 'pred_xstart': the x_0 predictions.
|
||||
"""
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
|
||||
x_start=x_start, x_t=x_t, t=t
|
||||
)
|
||||
out = self.p_mean_variance(
|
||||
model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
|
||||
)
|
||||
kl = normal_kl(
|
||||
true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
|
||||
)
|
||||
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
|
||||
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
|
||||
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
|
||||
kl = mean_flat(kl) / np.log(2.0)
|
||||
|
||||
decoder_nll = -discretized_gaussian_log_likelihood(
|
||||
|
@ -969,7 +890,7 @@ class GaussianDiffusion:
|
|||
model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs)
|
||||
if isinstance(model_outputs, tuple):
|
||||
model_output = model_outputs[0]
|
||||
terms['extra_outputs'] = model_outputs[1:]
|
||||
terms["extra_outputs"] = model_outputs[1:]
|
||||
else:
|
||||
model_output = model_outputs
|
||||
|
||||
|
@ -996,9 +917,7 @@ class GaussianDiffusion:
|
|||
terms["vb"] *= self.num_timesteps / 1000.0
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
target = self.q_posterior_mean_variance(
|
||||
x_start=x_start, x_t=x_t, t=t
|
||||
)[0]
|
||||
target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
|
||||
x_start_pred = torch.zeros(x_start) # Not supported.
|
||||
elif self.model_mean_type == ModelMeanType.START_X:
|
||||
target = x_start
|
||||
|
@ -1020,7 +939,9 @@ class GaussianDiffusion:
|
|||
|
||||
return terms
|
||||
|
||||
def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
|
||||
def autoregressive_training_losses(
|
||||
self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None
|
||||
):
|
||||
"""
|
||||
Compute training losses for a single timestep.
|
||||
|
||||
|
@ -1068,9 +989,7 @@ class GaussianDiffusion:
|
|||
terms["vb"] *= self.num_timesteps / 1000.0
|
||||
|
||||
if self.model_mean_type == ModelMeanType.PREVIOUS_X:
|
||||
target = self.q_posterior_mean_variance(
|
||||
x_start=x_start, x_t=x_t, t=t
|
||||
)[0]
|
||||
target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0]
|
||||
x_start_pred = torch.zeros(x_start) # Not supported.
|
||||
elif self.model_mean_type == ModelMeanType.START_X:
|
||||
target = x_start
|
||||
|
@ -1105,9 +1024,7 @@ class GaussianDiffusion:
|
|||
batch_size = x_start.shape[0]
|
||||
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
||||
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
||||
kl_prior = normal_kl(
|
||||
mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
|
||||
)
|
||||
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
||||
return mean_flat(kl_prior) / np.log(2.0)
|
||||
|
||||
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
||||
|
@ -1183,9 +1100,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
|||
scale = 1000 / num_diffusion_timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return np.linspace(
|
||||
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
|
||||
)
|
||||
return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
||||
elif schedule_name == "cosine":
|
||||
return betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
|
@ -1219,19 +1134,13 @@ class SpacedDiffusion(GaussianDiffusion):
|
|||
kwargs["betas"] = np.array(new_betas)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def p_mean_variance(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
||||
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def training_losses(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
||||
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
||||
|
||||
def autoregressive_training_losses(
|
||||
self, model, *args, **kwargs
|
||||
): # pylint: disable=signature-differs
|
||||
def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
||||
return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs)
|
||||
|
||||
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||
|
@ -1244,9 +1153,7 @@ class SpacedDiffusion(GaussianDiffusion):
|
|||
if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel):
|
||||
return model
|
||||
mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel
|
||||
return mod(
|
||||
model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
|
||||
)
|
||||
return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)
|
||||
|
||||
def _scale_timesteps(self, t):
|
||||
# Scaling is done by the wrapped model.
|
||||
|
@ -1281,9 +1188,7 @@ def space_timesteps(num_timesteps, section_counts):
|
|||
for i in range(1, num_timesteps):
|
||||
if len(range(0, num_timesteps, i)) == desired_count:
|
||||
return set(range(0, num_timesteps, i))
|
||||
raise ValueError(
|
||||
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
||||
)
|
||||
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
|
||||
section_counts = [int(x) for x in section_counts.split(",")]
|
||||
size_per = num_timesteps // len(section_counts)
|
||||
extra = num_timesteps % len(section_counts)
|
||||
|
@ -1292,9 +1197,7 @@ def space_timesteps(num_timesteps, section_counts):
|
|||
for i, section_count in enumerate(section_counts):
|
||||
size = size_per + (1 if i < extra else 0)
|
||||
if size < section_count:
|
||||
raise ValueError(
|
||||
f"cannot divide section of {size} steps into {section_count}"
|
||||
)
|
||||
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
|
||||
if section_count <= 1:
|
||||
frac_stride = 1
|
||||
else:
|
||||
|
@ -1315,6 +1218,7 @@ class _WrappedModel:
|
|||
self.timestep_map = timestep_map
|
||||
self.rescale_timesteps = rescale_timesteps
|
||||
self.original_num_steps = original_num_steps
|
||||
|
||||
def __call__(self, x, ts, **kwargs):
|
||||
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
||||
new_ts = map_tensor[ts]
|
||||
|
@ -1323,6 +1227,7 @@ class _WrappedModel:
|
|||
model_output = self.model(x, new_ts, **kwargs)
|
||||
return model_output
|
||||
|
||||
|
||||
class _WrappedAutoregressiveModel:
|
||||
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
|
||||
self.model = model
|
||||
|
@ -1337,6 +1242,7 @@ class _WrappedAutoregressiveModel:
|
|||
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||
return self.model(x, x0, new_ts, **kwargs)
|
||||
|
||||
|
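The wrapped models above exist only to translate the sampler's compressed step indices back into the timesteps the network was trained on. A toy illustration (the map and step count are made-up values):

```python
import torch

timestep_map = [0, 250, 500, 750, 999]  # assumed spacing, for illustration only
original_num_steps = 1000

ts = torch.tensor([0, 2, 4])  # indices the spaced sampler hands to the wrapper
map_tensor = torch.tensor(timestep_map, device=ts.device, dtype=ts.dtype)
new_ts = map_tensor[ts]  # tensor([  0, 500, 999]) -- original-schedule timesteps
# With rescaling enabled, timesteps are additionally normalised to a 0-1000 range:
rescaled = new_ts.float() * (1000.0 / original_num_steps)
```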
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
@ -1350,4 +1256,4 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
|||
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res.expand(broadcast_shape)
|
||||
return res.expand(broadcast_shape)
|
||||
|
|
|
@ -29,11 +29,9 @@ def timestep_embedding(timesteps, dim, max_period=10000):
|
|||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=timesteps.device)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
device=timesteps.device
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
|
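A compact restatement of the sinusoidal embedding assembled above, kept separate so the shape bookkeeping is easy to follow (a sketch, not the module the model imports):

```python
import math

import torch


def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None].to(timesteps.device)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # odd dims get a zero-padded last channel
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
```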
@ -98,17 +96,13 @@ class ResBlock(TimestepBlock):
|
|||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
nn.Conv1d(
|
||||
self.out_channels, self.out_channels, kernel_size, padding=padding
|
||||
),
|
||||
nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = nn.Conv1d(
|
||||
channels, self.out_channels, eff_kernel, padding=eff_padding
|
||||
)
|
||||
self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
|
||||
|
||||
def forward(self, x, emb):
|
||||
h = self.in_layers(x)
|
||||
|
@ -137,9 +131,7 @@ class DiffusionLayer(TimestepBlock):
|
|||
dims=1,
|
||||
use_scale_shift_norm=True,
|
||||
)
|
||||
self.attn = AttentionBlock(
|
||||
model_channels, num_heads, relative_pos_embeddings=True
|
||||
)
|
||||
self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
|
||||
|
||||
def forward(self, x, time_emb):
|
||||
y = self.resblk(x, time_emb)
|
||||
|
@ -239,16 +231,11 @@ class DiffusionTts(nn.Module):
|
|||
DiffusionLayer(model_channels, dropout, num_heads),
|
||||
)
|
||||
|
||||
self.integrating_conv = nn.Conv1d(
|
||||
model_channels * 2, model_channels, kernel_size=1
|
||||
)
|
||||
self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1)
|
||||
self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DiffusionLayer(model_channels, dropout, num_heads)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
[DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)]
|
||||
+ [
|
||||
ResBlock(
|
||||
model_channels,
|
||||
|
@ -275,9 +262,7 @@ class DiffusionTts(nn.Module):
|
|||
+ list(self.code_converter.parameters())
|
||||
+ list(self.latent_conditioner.parameters())
|
||||
+ list(self.latent_conditioner.parameters()),
|
||||
"timestep_integrator": list(
|
||||
self.conditioning_timestep_integrator.parameters()
|
||||
)
|
||||
"timestep_integrator": list(self.conditioning_timestep_integrator.parameters())
|
||||
+ list(self.integrating_conv.parameters()),
|
||||
"time_embed": list(self.time_embed.parameters()),
|
||||
}
|
||||
|
@ -285,9 +270,7 @@ class DiffusionTts(nn.Module):
|
|||
|
||||
def get_conditioning(self, conditioning_input):
|
||||
speech_conditioning_input = (
|
||||
conditioning_input.unsqueeze(1)
|
||||
if len(conditioning_input.shape) == 3
|
||||
else conditioning_input
|
||||
conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
|
||||
)
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
|
@ -313,29 +296,20 @@ class DiffusionTts(nn.Module):
|
|||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
code_emb = self.code_norm(code_emb) * (
|
||||
1 + cond_scale.unsqueeze(-1)
|
||||
) + cond_shift.unsqueeze(-1)
|
||||
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
unconditioned_batches = torch.zeros(
|
||||
(code_emb.shape[0], 1, 1), device=code_emb.device
|
||||
)
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = (
|
||||
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
< self.unconditioned_percentage
|
||||
torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage
|
||||
)
|
||||
code_emb = torch.where(
|
||||
unconditioned_batches,
|
||||
self.unconditioned_embedding.repeat(
|
||||
aligned_conditioning.shape[0], 1, 1
|
||||
),
|
||||
self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb,
|
||||
)
|
||||
expanded_code_emb = F.interpolate(
|
||||
code_emb, size=expected_seq_len, mode="nearest"
|
||||
)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest")
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
|
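The conditioning dropout reformatted above is worth spelling out: during training, whole batch items occasionally have their conditioning swapped for a learned unconditioned embedding, which is what later enables classifier-free-guidance-style sampling. A self-contained sketch with assumed shapes:

```python
import torch

B, C, T = 4, 8, 1  # assumed batch, channel and length, purely for illustration
code_emb = torch.randn(B, C, T)
unconditioned_embedding = torch.nn.Parameter(torch.randn(1, C, 1))
unconditioned_percentage = 0.1

# Per-item coin flip; True means "drop the conditioning for this sample".
drop = torch.rand(B, 1, 1) < unconditioned_percentage
code_emb = torch.where(drop, unconditioned_embedding.repeat(B, 1, 1), code_emb)
```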
@ -376,10 +350,7 @@ class DiffusionTts(nn.Module):
|
|||
unused_params = []
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters())
|
||||
+ list(self.code_embedding.parameters())
|
||||
)
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
|
@ -390,8 +361,7 @@ class DiffusionTts(nn.Module):
|
|||
)
|
||||
if is_latent(aligned_conditioning):
|
||||
unused_params.extend(
|
||||
list(self.code_converter.parameters())
|
||||
+ list(self.code_embedding.parameters())
|
||||
list(self.code_converter.parameters()) + list(self.code_embedding.parameters())
|
||||
)
|
||||
else:
|
||||
unused_params.extend(list(self.latent_conditioner.parameters()))
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class NoiseScheduleVP:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -107,11 +109,7 @@ class NoiseScheduleVP:
|
|||
log_alphas = 0.5 * torch.log(alphas_cumprod)
|
||||
self.total_N = len(log_alphas)
|
||||
self.T = 1.0
|
||||
self.t_array = (
|
||||
torch.linspace(0.0, 1.0, self.total_N + 1)[1:]
|
||||
.reshape((1, -1))
|
||||
.to(dtype=dtype)
|
||||
)
|
||||
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
|
||||
self.log_alpha_array = log_alphas.reshape(
|
||||
(
|
||||
1,
|
||||
|
@ -131,9 +129,7 @@ class NoiseScheduleVP:
|
|||
/ math.pi
|
||||
- self.cosine_s
|
||||
)
|
||||
self.cosine_log_alpha_0 = math.log(
|
||||
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
|
||||
)
|
||||
self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
|
||||
self.schedule = schedule
|
||||
if schedule == "cosine":
|
||||
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
|
||||
|
@ -157,11 +153,7 @@ class NoiseScheduleVP:
|
|||
elif self.schedule == "cosine":
|
||||
|
||||
def log_alpha_fn(s):
|
||||
return torch.log(
|
||||
torch.cos(
|
||||
(s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0
|
||||
)
|
||||
)
|
||||
return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
|
||||
|
||||
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
|
||||
return log_alpha_t
|
||||
|
@ -191,17 +183,11 @@ class NoiseScheduleVP:
|
|||
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
|
||||
"""
|
||||
if self.schedule == "linear":
|
||||
tmp = (
|
||||
2.0
|
||||
* (self.beta_1 - self.beta_0)
|
||||
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
||||
)
|
||||
tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
|
||||
Delta = self.beta_0**2 + tmp
|
||||
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
|
||||
elif self.schedule == "discrete":
|
||||
log_alpha = -0.5 * torch.logaddexp(
|
||||
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
|
||||
)
|
||||
log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
|
||||
t = interpolate_fn(
|
||||
log_alpha.reshape((-1, 1)),
|
||||
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
|
||||
|
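The identities behind inverse_lambda are easy to lose in the reflow: for a VP schedule, lambda_t = log(alpha_t) - log(sigma_t) with sigma_t = sqrt(1 - alpha_t^2), so log(alpha) can be recovered from lambda with the same logaddexp expression used in the "discrete" branch. A quick numerical check with an arbitrary value:

```python
import torch

alpha_t = torch.tensor(0.8)
sigma_t = torch.sqrt(1 - alpha_t**2)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)  # half-logSNR

# log(alpha) = -0.5 * log(1 + exp(-2 * lambda)), i.e. the logaddexp form above.
log_alpha = -0.5 * torch.logaddexp(torch.zeros(()), -2.0 * lambda_t)
assert torch.allclose(log_alpha, torch.log(alpha_t), atol=1e-6)
```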
@ -345,14 +331,10 @@ def model_wrapper(
|
|||
if model_type == "noise":
|
||||
return output
|
||||
elif model_type == "x_start":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(
|
||||
t_continuous
|
||||
), noise_schedule.marginal_std(t_continuous)
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
return (x - alpha_t * output) / sigma_t
|
||||
elif model_type == "v":
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(
|
||||
t_continuous
|
||||
), noise_schedule.marginal_std(t_continuous)
|
||||
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
|
||||
return alpha_t * output + sigma_t * x
|
||||
elif model_type == "score":
|
||||
sigma_t = noise_schedule.marginal_std(t_continuous)
|
||||
|
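The branches above only convert the different model parameterizations into a noise prediction. Written out once for scalars (numbers are arbitrary; the score case follows the standard DPM-Solver wrapper, which this hunk truncates):

```python
import torch

alpha_t, sigma_t = torch.tensor(0.8), torch.tensor(0.6)
x = torch.tensor(1.0)

x0_pred = torch.tensor(0.5)
eps_from_x0 = (x - alpha_t * x0_pred) / sigma_t  # "x_start" model

v_pred = torch.tensor(0.2)
eps_from_v = alpha_t * v_pred + sigma_t * x      # "v" model

score_pred = torch.tensor(-0.3)
eps_from_score = -sigma_t * score_pred           # "score" model
```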
@ -482,9 +464,7 @@ class DPM_Solver:
|
|||
p = self.dynamic_thresholding_ratio
|
||||
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
||||
s = expand_dims(
|
||||
torch.maximum(
|
||||
s, self.thresholding_max_val * torch.ones_like(s).to(s.device)
|
||||
),
|
||||
torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)),
|
||||
dims,
|
||||
)
|
||||
x0 = torch.clamp(x0, -s, s) / s
|
||||
|
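A standalone version of the dynamic-thresholding correction above, assuming the usual Imagen-style defaults for the quantile ratio and ceiling:

```python
import torch


def dynamic_threshold(x0, ratio=0.995, max_val=1.0):
    # Per-sample quantile of |x0|, floored at max_val, then clamp and renormalise.
    s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), ratio, dim=1)
    s = torch.maximum(s, max_val * torch.ones_like(s))
    s = s.reshape(-1, *([1] * (x0.dim() - 1)))  # broadcast over x0's trailing dims
    return torch.clamp(x0, -s, s) / s
```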
@ -501,9 +481,7 @@ class DPM_Solver:
|
|||
Return the data prediction model (with corrector).
|
||||
"""
|
||||
noise = self.noise_prediction_fn(x, t)
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
|
||||
t
|
||||
), self.noise_schedule.marginal_std(t)
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||
x0 = (x - sigma_t * noise) / alpha_t
|
||||
if self.correcting_x0_fn is not None:
|
||||
x0 = self.correcting_x0_fn(x0, t)
|
||||
|
@ -536,30 +514,20 @@ class DPM_Solver:
|
|||
if skip_type == "logSNR":
|
||||
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
|
||||
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
|
||||
logSNR_steps = torch.linspace(
|
||||
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
|
||||
).to(device)
|
||||
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
|
||||
return self.noise_schedule.inverse_lambda(logSNR_steps)
|
||||
elif skip_type == "time_uniform":
|
||||
return torch.linspace(t_T, t_0, N + 1).to(device)
|
||||
elif skip_type == "time_quadratic":
|
||||
t_order = 2
|
||||
t = (
|
||||
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
|
||||
.pow(t_order)
|
||||
.to(device)
|
||||
)
|
||||
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
|
||||
return t
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
|
||||
skip_type
|
||||
)
|
||||
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
|
||||
)
|
||||
|
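The three skip types above in miniature, for t_T=1.0, t_0=1e-3 and N=5; the logSNR variant needs the schedule's lambda(t) and its inverse, so only the other two are shown literally:

```python
import torch

t_T, t_0, N = 1.0, 1e-3, 5
uniform = torch.linspace(t_T, t_0, N + 1)                     # "time_uniform"
quadratic = torch.linspace(t_T**0.5, t_0**0.5, N + 1).pow(2)  # "time_quadratic"
# "logSNR" spaces steps uniformly in lambda(t), then maps back via inverse_lambda.
```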
||||
def get_orders_and_timesteps_for_singlestep_solver(
|
||||
self, steps, order, skip_type, t_T, t_0, device
|
||||
):
|
||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||
"""
|
||||
Get the order of each step for sampling by the singlestep DPM-Solver.
|
||||
|
||||
|
@ -664,9 +632,7 @@ class DPM_Solver:
|
|||
dims = x.dim()
|
||||
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
|
||||
h = lambda_t - lambda_s
|
||||
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
|
||||
s
|
||||
), ns.marginal_log_mean_coeff(t)
|
||||
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
|
||||
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
|
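The marginal quantities gathered above feed the solver updates. For reference, the first-order step in noise-prediction form (the "dpmsolver" algorithm type), as a sketch rather than the class method itself:

```python
import torch


def first_order_update(x, eps_s, log_alpha_s, log_alpha_t, sigma_t, lambda_s, lambda_t):
    # x_t = (alpha_t / alpha_s) * x_s - sigma_t * (e^h - 1) * eps(x_s, s), with h = lambda_t - lambda_s
    h = lambda_t - lambda_s
    phi_1 = torch.expm1(h)
    return torch.exp(log_alpha_t - log_alpha_s) * x - sigma_t * phi_1 * eps_s
```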
@ -716,11 +682,7 @@ class DPM_Solver:
|
|||
x_t: A pytorch tensor. The approximated solution at time `t`.
|
||||
"""
|
||||
if solver_type not in ["dpmsolver", "taylor"]:
|
||||
raise ValueError(
|
||||
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
|
||||
solver_type
|
||||
)
|
||||
)
|
||||
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
||||
if r1 is None:
|
||||
r1 = 0.5
|
||||
ns = self.noise_schedule
|
||||
|
@ -766,10 +728,7 @@ class DPM_Solver:
|
|||
|
||||
if model_s is None:
|
||||
model_s = self.model_fn(x, s)
|
||||
x_s1 = (
|
||||
torch.exp(log_alpha_s1 - log_alpha_s) * x
|
||||
- (sigma_s1 * phi_11) * model_s
|
||||
)
|
||||
x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
|
||||
model_s1 = self.model_fn(x_s1, s1)
|
||||
if solver_type == "dpmsolver":
|
||||
x_t = (
|
||||
|
@ -820,11 +779,7 @@ class DPM_Solver:
|
|||
x_t: A pytorch tensor. The approximated solution at time `t`.
|
||||
"""
|
||||
if solver_type not in ["dpmsolver", "taylor"]:
|
||||
raise ValueError(
|
||||
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
|
||||
solver_type
|
||||
)
|
||||
)
|
||||
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
||||
if r1 is None:
|
||||
r1 = 1.0 / 3.0
|
||||
if r2 is None:
|
||||
|
@ -901,9 +856,7 @@ class DPM_Solver:
|
|||
if model_s is None:
|
||||
model_s = self.model_fn(x, s)
|
||||
if model_s1 is None:
|
||||
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (
|
||||
sigma_s1 * phi_11
|
||||
) * model_s
|
||||
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
|
||||
model_s1 = self.model_fn(x_s1, s1)
|
||||
x_s2 = (
|
||||
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
|
||||
|
@ -934,9 +887,7 @@ class DPM_Solver:
|
|||
else:
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_second_update(
|
||||
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
|
||||
):
|
||||
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
|
||||
"""
|
||||
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
|
||||
|
||||
|
@ -951,11 +902,7 @@ class DPM_Solver:
|
|||
x_t: A pytorch tensor. The approximated solution at time `t`.
|
||||
"""
|
||||
if solver_type not in ["dpmsolver", "taylor"]:
|
||||
raise ValueError(
|
||||
"'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(
|
||||
solver_type
|
||||
)
|
||||
)
|
||||
raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
|
||||
ns = self.noise_schedule
|
||||
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
|
||||
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
|
||||
|
@ -964,9 +911,7 @@ class DPM_Solver:
|
|||
ns.marginal_lambda(t_prev_0),
|
||||
ns.marginal_lambda(t),
|
||||
)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
|
||||
t_prev_0
|
||||
), ns.marginal_log_mean_coeff(t)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
|
@ -977,11 +922,7 @@ class DPM_Solver:
|
|||
if self.algorithm_type == "dpmsolver++":
|
||||
phi_1 = torch.expm1(-h)
|
||||
if solver_type == "dpmsolver":
|
||||
x_t = (
|
||||
(sigma_t / sigma_prev_0) * x
|
||||
- (alpha_t * phi_1) * model_prev_0
|
||||
- 0.5 * (alpha_t * phi_1) * D1_0
|
||||
)
|
||||
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
|
||||
elif solver_type == "taylor":
|
||||
x_t = (
|
||||
(sigma_t / sigma_prev_0) * x
|
||||
|
@ -1004,9 +945,7 @@ class DPM_Solver:
|
|||
)
|
||||
return x_t
|
||||
|
||||
def multistep_dpm_solver_third_update(
|
||||
self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"
|
||||
):
|
||||
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
|
||||
"""
|
||||
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
|
||||
|
||||
|
@ -1029,9 +968,7 @@ class DPM_Solver:
|
|||
ns.marginal_lambda(t_prev_0),
|
||||
ns.marginal_lambda(t),
|
||||
)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
|
||||
t_prev_0
|
||||
), ns.marginal_log_mean_coeff(t)
|
||||
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
|
||||
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
|
||||
alpha_t = torch.exp(log_alpha_t)
|
||||
|
||||
|
@ -1093,9 +1030,7 @@ class DPM_Solver:
|
|||
x_t: A pytorch tensor. The approximated solution at time `t`.
|
||||
"""
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_update(
|
||||
x, s, t, return_intermediate=return_intermediate
|
||||
)
|
||||
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
|
||||
elif order == 2:
|
||||
return self.singlestep_dpm_solver_second_update(
|
||||
x,
|
||||
|
@ -1118,9 +1053,7 @@ class DPM_Solver:
|
|||
else:
|
||||
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
||||
|
||||
def multistep_dpm_solver_update(
|
||||
self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"
|
||||
):
|
||||
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
|
||||
"""
|
||||
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
|
||||
|
||||
|
@ -1136,17 +1069,11 @@ class DPM_Solver:
|
|||
x_t: A pytorch tensor. The approximated solution at time `t`.
|
||||
"""
|
||||
if order == 1:
|
||||
return self.dpm_solver_first_update(
|
||||
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
|
||||
)
|
||||
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
|
||||
elif order == 2:
|
||||
return self.multistep_dpm_solver_second_update(
|
||||
x, model_prev_list, t_prev_list, t, solver_type=solver_type
|
||||
)
|
||||
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
||||
elif order == 3:
|
||||
return self.multistep_dpm_solver_third_update(
|
||||
x, model_prev_list, t_prev_list, t, solver_type=solver_type
|
||||
)
|
||||
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
|
||||
else:
|
||||
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
|
||||
|
||||
|
@ -1198,9 +1125,7 @@ class DPM_Solver:
|
|||
return self.dpm_solver_first_update(x, s, t, return_intermediate=True)
|
||||
|
||||
def higher_update(x, s, t, **kwargs):
|
||||
return self.singlestep_dpm_solver_second_update(
|
||||
x, s, t, r1=r1, solver_type=solver_type, **kwargs
|
||||
)
|
||||
return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
|
||||
|
||||
elif order == 3:
|
||||
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
|
||||
|
@ -1211,16 +1136,10 @@ class DPM_Solver:
|
|||
)
|
||||
|
||||
def higher_update(x, s, t, **kwargs):
|
||||
return self.singlestep_dpm_solver_third_update(
|
||||
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
|
||||
)
|
||||
return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"For adaptive step size solver, order must be 2 or 3, got {}".format(
|
||||
order
|
||||
)
|
||||
)
|
||||
raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
|
||||
while torch.abs((s - t_0)).mean() > t_err:
|
||||
t = ns.inverse_lambda(lambda_s + h)
|
||||
x_lower, lower_noise_kwargs = lower_update(x, s, t)
|
||||
|
@ -1231,9 +1150,7 @@ class DPM_Solver:
|
|||
)
|
||||
|
||||
def norm_fn(v):
|
||||
return torch.sqrt(
|
||||
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
|
||||
)
|
||||
return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
|
||||
|
||||
E = norm_fn((x_higher - x_lower) / delta).max()
|
||||
if torch.all(E <= 1.0):
|
||||
|
@ -1259,9 +1176,7 @@ class DPM_Solver:
|
|||
Returns:
|
||||
xt with shape `(t_size, batch_size, *shape)`.
|
||||
"""
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
|
||||
t
|
||||
), self.noise_schedule.marginal_std(t)
|
||||
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
|
||||
if noise is None:
|
||||
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
|
||||
x = x.reshape((-1, *x.shape))
|
||||
|
@ -1468,9 +1383,7 @@ class DPM_Solver:
|
|||
)
|
||||
elif method == "multistep":
|
||||
assert steps >= order
|
||||
timesteps = self.get_time_steps(
|
||||
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
|
||||
)
|
||||
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
|
||||
assert timesteps.shape[0] - 1 == steps
|
||||
# Init the initial values.
|
||||
step = 0
|
||||
|
@ -1527,10 +1440,7 @@ class DPM_Solver:
|
|||
model_prev_list[-1] = self.model_fn(x, t)
|
||||
elif method in ["singlestep", "singlestep_fixed"]:
|
||||
if method == "singlestep":
|
||||
(
|
||||
timesteps_outer,
|
||||
orders,
|
||||
) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
(timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
|
||||
steps=steps,
|
||||
order=order,
|
||||
skip_type=skip_type,
|
||||
|
@ -1543,9 +1453,7 @@ class DPM_Solver:
|
|||
orders = [
|
||||
order,
|
||||
] * K
|
||||
timesteps_outer = self.get_time_steps(
|
||||
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
|
||||
)
|
||||
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
|
||||
for step, order in enumerate(orders):
|
||||
s, t = timesteps_outer[step], timesteps_outer[step + 1]
|
||||
timesteps_inner = self.get_time_steps(
|
||||
|
@ -1559,9 +1467,7 @@ class DPM_Solver:
|
|||
h = lambda_inner[-1] - lambda_inner[0]
|
||||
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
|
||||
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
|
||||
x = self.singlestep_dpm_solver_update(
|
||||
x, s, t, order, solver_type=solver_type, r1=r1, r2=r2
|
||||
)
|
||||
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
|
||||
if self.correcting_xt_fn is not None:
|
||||
x = self.correcting_xt_fn(x, t, step)
|
||||
if return_intermediate:
|
||||
|
@ -1613,9 +1519,7 @@ def interpolate_fn(x, xp, yp):
|
|||
cand_start_idx,
|
||||
),
|
||||
)
|
||||
end_idx = torch.where(
|
||||
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
|
||||
)
|
||||
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
|
||||
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
|
||||
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
|
||||
start_idx2 = torch.where(
|
||||
|
@ -1628,12 +1532,8 @@ def interpolate_fn(x, xp, yp):
|
|||
),
|
||||
)
|
||||
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
|
||||
start_y = torch.gather(
|
||||
y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
end_y = torch.gather(
|
||||
y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
|
||||
).squeeze(2)
|
||||
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
|
||||
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
|
||||
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
|
||||
return cand
|
||||
|
||||
|
|
|
@ -40,8 +40,7 @@ class RandomLatentConverter(nn.Module):
|
|||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)],
|
||||
nn.Linear(channels, channels)
|
||||
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels)
|
||||
)
|
||||
self.channels = channels
|
||||
|
||||
|
|
|
@ -95,9 +95,7 @@ def _expand_number(m):
|
|||
elif num % 100 == 0:
|
||||
return _inflect.number_to_words(num // 100) + " hundred"
|
||||
else:
|
||||
return _inflect.number_to_words(
|
||||
num, andword="", zero="oh", group=2
|
||||
).replace(", ", " ")
|
||||
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
|
||||
else:
|
||||
return _inflect.number_to_words(num, andword="")
|
||||
|
||||
|
@ -165,9 +163,7 @@ def lev_distance(s1, s2):
|
|||
if c1 == c2:
|
||||
distances_.append(distances[i1])
|
||||
else:
|
||||
distances_.append(
|
||||
1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
|
||||
)
|
||||
distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
|
||||
distances = distances_
|
||||
return distances[-1]
|
||||
|
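For readers who only see the reflowed inner line above, the full single-row dynamic-programming edit distance looks like this (a generic restatement, not the file's exact loop):

```python
def levenshtein(s1, s2):
    if len(s1) < len(s2):
        s1, s2 = s2, s1
    distances = list(range(len(s2) + 1))
    for c1 in s1:
        row = [distances[0] + 1]
        for i2, c2 in enumerate(s2):
            if c1 == c2:
                row.append(distances[i2])  # characters match: take the diagonal
            else:
                row.append(1 + min(distances[i2], distances[i2 + 1], row[-1]))
        distances = row
    return distances[-1]
```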
||||
|
|
|
@ -36,12 +36,8 @@ def route_args(router, args, depth):
|
|||
|
||||
for key in matched_keys:
|
||||
val = args[key]
|
||||
for depth, ((f_args, g_args), routes) in enumerate(
|
||||
zip(routed_args, router[key])
|
||||
):
|
||||
new_f_args, new_g_args = map(
|
||||
lambda route: ({key: val} if route else {}), routes
|
||||
)
|
||||
for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
|
||||
new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
|
||||
routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
|
||||
return routed_args
|
||||
|
||||
|
@ -217,12 +213,8 @@ class Transformer(nn.Module):
|
|||
layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
LayerScale(
|
||||
dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)
|
||||
),
|
||||
LayerScale(
|
||||
dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)
|
||||
),
|
||||
LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)),
|
||||
LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import os
|
||||
try: import gdown
|
||||
|
||||
try:
|
||||
import gdown
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Sorry, gdown is required in order to download the new BigVGAN vocoder.\n"
|
||||
|
@ -11,9 +13,7 @@ import progressbar
|
|||
|
||||
D_STEM = "https://drive.google.com/uc?id="
|
||||
|
||||
DEFAULT_MODELS_DIR = os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "tortoise", "models"
|
||||
)
|
||||
DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models")
|
||||
MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR)
|
||||
MODELS = {
|
||||
"autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth",
|
||||
|
@ -30,6 +30,8 @@ MODELS = {
|
|||
}
|
||||
|
||||
pbar = None
|
||||
|
||||
|
||||
def download_models(specific_models=None):
|
||||
"""
|
||||
Call to download all the models that Tortoise uses.
|
||||
|
@ -62,6 +64,7 @@ def download_models(specific_models=None):
|
|||
request.urlretrieve(url, model_path, show_progress)
|
||||
print("Done.")
|
||||
|
||||
|
||||
def get_model_path(model_name, models_dir=MODELS_DIR):
|
||||
"""
|
||||
Get path to given model, download it if it doesn't exist.
|
||||
|
@ -71,4 +74,4 @@ def get_model_path(model_name, models_dir=MODELS_DIR):
|
|||
model_path = os.path.join(models_dir, model_name)
|
||||
if not os.path.exists(model_path) and models_dir == MODELS_DIR:
|
||||
download_models([model_name])
|
||||
return model_path
|
||||
return model_path
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
|
@ -40,18 +40,12 @@ class KernelPredictor(torch.nn.Module):
|
|||
self.conv_kernel_size = conv_kernel_size
|
||||
self.conv_layers = conv_layers
|
||||
|
||||
kpnet_kernel_channels = (
|
||||
conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
|
||||
) # l_w
|
||||
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
||||
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
||||
|
||||
self.input_conv = nn.Sequential(
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params
|
||||
),
|
||||
nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
|
||||
self.residual_convs = nn.ModuleList()
|
||||
|
@ -69,9 +63,7 @@ class KernelPredictor(torch.nn.Module):
|
|||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(
|
||||
kpnet_hidden_channels,
|
||||
|
@ -81,9 +73,7 @@ class KernelPredictor(torch.nn.Module):
|
|||
bias=True,
|
||||
)
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(
|
||||
**kpnet_nonlinear_activation_params
|
||||
),
|
||||
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
||||
)
|
||||
)
|
||||
self.kernel_conv = nn.utils.weight_norm(
|
||||
|
@ -252,17 +242,11 @@ class LVCBlock(torch.nn.Module):
|
|||
"""
|
||||
batch, _, in_length = x.shape
|
||||
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
||||
assert in_length == (
|
||||
kernel_length * hop_size
|
||||
), "length of (x, kernel) is not matched"
|
||||
assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched"
|
||||
|
||||
padding = dilation * int((kernel_size - 1) / 2)
|
||||
x = F.pad(
|
||||
x, (padding, padding), "constant", 0
|
||||
) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(
|
||||
2, hop_size + 2 * padding, hop_size
|
||||
) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
||||
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
||||
|
||||
if hop_size < dilation:
|
||||
x = F.pad(x, (0, dilation), "constant", 0)
|
||||
|
@ -270,12 +254,8 @@ class LVCBlock(torch.nn.Module):
|
|||
3, dilation, dilation
|
||||
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
||||
x = x[:, :, :, :, :hop_size]
|
||||
x = x.transpose(
|
||||
3, 4
|
||||
) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(
|
||||
4, kernel_size, 1
|
||||
) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
||||
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
||||
|
||||
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
||||
o = o.to(memory_format=torch.channels_last_3d)
|
||||
|
@ -334,15 +314,11 @@ class UnivNetGenerator(nn.Module):
|
|||
)
|
||||
)
|
||||
|
||||
self.conv_pre = nn.utils.weight_norm(
|
||||
nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")
|
||||
)
|
||||
self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect"))
|
||||
|
||||
self.conv_post = nn.Sequential(
|
||||
nn.LeakyReLU(lReLU_slope),
|
||||
nn.utils.weight_norm(
|
||||
nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")
|
||||
),
|
||||
nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
|
@ -399,12 +375,16 @@ class VocType:
|
|||
constructor: Callable[[], nn.Module]
|
||||
model_path: str
|
||||
subkey: Optional[str] = None
|
||||
|
||||
def optionally_index(self, model_dict):
|
||||
if self.subkey is not None:
|
||||
return model_dict[self.subkey]
|
||||
return model_dict
|
||||
|
||||
|
||||
class VocConf(Enum):
|
||||
Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g')
|
||||
Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = UnivNetGenerator()
|
||||
|
|
|
@ -12,9 +12,7 @@ def max_alignment(s1, s2, skip_character="~", record=None):
|
|||
"""
|
||||
if record is None:
|
||||
record = {}
|
||||
assert (
|
||||
skip_character not in s1
|
||||
), f"Found the skip character {skip_character} in the provided string, {s1}"
|
||||
assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
|
||||
if len(s1) == 0:
|
||||
return ""
|
||||
if len(s2) == 0:
|
||||
|
@ -49,15 +47,9 @@ class Wav2VecAlignment:
|
|||
"""
|
||||
|
||||
def __init__(self, device="cuda"):
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained(
|
||||
"jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
|
||||
).cpu()
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"facebook/wav2vec2-large-960h"
|
||||
)
|
||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
|
||||
"jbetker/tacotron-symbols"
|
||||
)
|
||||
self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
|
||||
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols")
|
||||
self.device = device
|
||||
|
||||
def align(self, audio, expected_text, audio_sample_rate=24000):
|
||||
|
@ -117,9 +109,7 @@ class Wav2VecAlignment:
|
|||
)
|
||||
|
||||
# Now fix up alignments. Anything with -1 should be interpolated.
|
||||
alignments.append(
|
||||
orig_len
|
||||
) # This'll get removed but makes the algorithm below more readable.
|
||||
alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
|
||||
for i in range(len(alignments)):
|
||||
if alignments[i] == -1:
|
||||
for j in range(i + 1, len(alignments)):
|
||||
|
@ -128,9 +118,7 @@ class Wav2VecAlignment:
|
|||
break
|
||||
for j in range(i, next_found_token):
|
||||
gap = alignments[next_found_token] - alignments[i - 1]
|
||||
alignments[j] = (j - i + 1) * gap // (
|
||||
next_found_token - i + 1
|
||||
) + alignments[i - 1]
|
||||
alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
|
||||
|
||||
return alignments[:-1]
|
||||
|
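A toy run of the gap-filling rule above: tokens the CTC alignment could not place (marked -1) are linearly interpolated between their placed neighbours.

```python
alignments = [10, -1, -1, 40]  # frame indices; -1 means "not found"
i, next_found_token = 1, 3
gap = alignments[next_found_token] - alignments[i - 1]  # 30
for j in range(i, next_found_token):
    alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
# alignments is now [10, 20, 30, 40]
```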
||||
|
@ -140,9 +128,7 @@ class Wav2VecAlignment:
|
|||
splitted = expected_text.split("[")
|
||||
fully_split = [splitted[0]]
|
||||
for spl in splitted[1:]:
|
||||
assert (
|
||||
"]" in spl
|
||||
), 'Every "[" character must be paired with a "]" with no nesting.'
|
||||
assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.'
|
||||
fully_split.extend(spl.split("]"))
|
||||
|
||||
# At this point, fully_split is a list of strings, with every other string being something that should be redacted.
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,31 +0,0 @@
|
|||
from torch import nn
|
||||
from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock
|
||||
|
||||
|
||||
class FramePriorNet(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, hidden_channels, kernel_size, num_res_blocks=13, num_conv_blocks=2
|
||||
):
|
||||
super().__init__()
|
||||
self.res_blocks = nn.ModuleList()
|
||||
for idx in range(num_res_blocks):
|
||||
block = Conv1dBNBlock(
|
||||
in_channels if idx == 0 else hidden_channels,
|
||||
out_channels if (idx + 1) == num_res_blocks else hidden_channels,
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
1,
|
||||
num_conv_blocks,
|
||||
)
|
||||
self.res_blocks.append(block)
|
||||
def forward(self, x, x_mask=None):
|
||||
if x_mask is None:
|
||||
x_mask = 1.0
|
||||
o = x * x_mask
|
||||
for block in self.res_blocks:
|
||||
res = o
|
||||
o = block(o)
|
||||
o = o + res
|
||||
if x_mask is not None:
|
||||
o = o * x_mask
|
||||
return o
|
|
@ -1,91 +0,0 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class ReferenceEncoder(nn.Module):
|
||||
"""NN module creating a fixed size prosody embedding from a spectrogram.
|
||||
|
||||
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
|
||||
outputs: [batch_size, embedding_dim]
|
||||
"""
|
||||
|
||||
def __init__(self, num_mel, filter):
|
||||
super().__init__()
|
||||
self.num_mel = num_mel
|
||||
start_index = 2
|
||||
end_index = filter / 16
|
||||
i = start_index
|
||||
filt_len = []
|
||||
while i <= end_index:
|
||||
i = i * 2
|
||||
filt_len.append(i)
|
||||
filt_len.append(i)
|
||||
filters = [1] + filt_len
|
||||
num_layers = len(filters) - 1
|
||||
convs = [
|
||||
nn.Conv2d(
|
||||
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.training = False
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
self.recurrence = nn.LSTM(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
|
||||
)
|
||||
|
||||
def forward(self, inputs, input_lengths):
|
||||
batch_size = inputs.size(0)
|
||||
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
|
||||
valid_lengths = input_lengths.float() # [batch_size]
|
||||
for conv, bn in zip(self.convs, self.bns):
|
||||
x = conv(x)
|
||||
x = bn(x)
|
||||
x = F.relu(x)
|
||||
|
||||
# Create the post conv width mask based on the valid lengths of the output of the convolution.
|
||||
# The valid lengths for the output of a convolution on varying length inputs is
|
||||
# ceil(input_length/stride) + 1 for stride=2 and padding=2
|
||||
# For example (kernel_size=3, stride=2, padding=2):
|
||||
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
|
||||
# _____
|
||||
# x _____
|
||||
# x _____
|
||||
# x ____
|
||||
# x
|
||||
# x x x x -> Output valid length = 4
|
||||
# Since every example in the batch is zero padded and therefore has separate valid_lengths,
|
||||
# we need to mask off all the values AFTER the valid length for each example in the batch.
|
||||
# Otherwise, the convolutions create noise and a lot of not real information
|
||||
valid_lengths = (valid_lengths / 2).float()
|
||||
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
|
||||
post_conv_max_width = x.size(2)
|
||||
|
||||
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
|
||||
len(valid_lengths), post_conv_max_width
|
||||
) < valid_lengths.unsqueeze(1)
|
||||
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
|
||||
x = x * mask
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
# x: 4D tensor [batch_size, post_conv_width,
|
||||
# num_channels==128, post_conv_height]
|
||||
|
||||
post_conv_width = x.size(1)
|
||||
x = x.contiguous().view(batch_size, post_conv_width, -1)
|
||||
# x: 3D tensor [batch_size, post_conv_width,
|
||||
# num_channels*post_conv_height]
|
||||
|
||||
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
|
||||
post_conv_input_lengths = valid_lengths
|
||||
packed_seqs = nn.utils.rnn.pack_padded_sequence(
|
||||
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
|
||||
) # dynamic rnn sequence padding
|
||||
self.recurrence.flatten_parameters()
|
||||
_, (ht, _) = self.recurrence(packed_seqs)
|
||||
last_output = ht[-1]
|
||||
|
||||
return last_output.to(inputs.device) # [B, 128]
|
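The length bookkeeping in the ReferenceEncoder comments above can be checked directly: with kernel_size=3, stride=2, padding=2, each conv maps a valid length L to floor((L + 2*padding - kernel)/stride) + 1, which equals the ceil(L/2) + 1 shortcut the code uses. A small sanity check:

```python
import math


def post_conv_len(L, kernel_size=3, stride=2, padding=2):
    return (L + 2 * padding - kernel_size) // stride + 1


for L in (5, 6, 7, 80):
    assert post_conv_len(L) == math.ceil(L / 2) + 1
```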
|
@ -1,67 +0,0 @@
|
|||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
class VectorQuantization(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, codebook):
|
||||
with torch.no_grad():
|
||||
embedding_size = codebook.size(1)
|
||||
inputs_size = inputs.size()
|
||||
inputs_flatten = inputs.view(-1, embedding_size)
|
||||
|
||||
codebook_sqr = torch.sum(codebook ** 2, dim=1)
|
||||
inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
|
||||
|
||||
# Compute the distances to the codebook
|
||||
distances = torch.addmm(codebook_sqr + inputs_sqr,
|
||||
inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
|
||||
|
||||
_, indices_flatten = torch.min(distances, dim=1)
|
||||
indices = indices_flatten.view(*inputs_size[:-1])
|
||||
ctx.mark_non_differentiable(indices)
|
||||
|
||||
return indices
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise RuntimeError('Trying to call `.grad()` on graph containing '
|
||||
'`VectorQuantization`. The function `VectorQuantization` '
|
||||
'is not differentiable. Use `VectorQuantizationStraightThrough` '
|
||||
'if you want a straight-through estimator of the gradient.')
|
||||
|
||||
class VectorQuantizationStraightThrough(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, codebook):
|
||||
indices = vq(inputs, codebook)
|
||||
indices_flatten = indices.view(-1)
|
||||
ctx.save_for_backward(indices_flatten, codebook)
|
||||
ctx.mark_non_differentiable(indices_flatten)
|
||||
|
||||
codes_flatten = torch.index_select(codebook, dim=0,
|
||||
index=indices_flatten)
|
||||
codes = codes_flatten.view_as(inputs)
|
||||
|
||||
return (codes, indices_flatten)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output, grad_indices):
|
||||
grad_inputs, grad_codebook = None, None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
# Straight-through estimator
|
||||
grad_inputs = grad_output.clone()
|
||||
if ctx.needs_input_grad[1]:
|
||||
# Gradient wrt. the codebook
|
||||
indices, codebook = ctx.saved_tensors
|
||||
embedding_size = codebook.size(1)
|
||||
|
||||
grad_output_flatten = (grad_output.contiguous()
|
||||
.view(-1, embedding_size))
|
||||
grad_codebook = torch.zeros_like(codebook)
|
||||
grad_codebook.index_add_(0, indices, grad_output_flatten)
|
||||
|
||||
return (grad_inputs, grad_codebook)
|
||||
|
||||
vq = VectorQuantization.apply
|
||||
vq_st = VectorQuantizationStraightThrough.apply
|
||||
__all__ = [vq, vq_st]
|
|
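The two autograd Functions above implement nearest-neighbour quantization plus the straight-through estimator. The core trick fits in a few lines; this sketch omits the codebook gradient that the real backward accumulates with index_add_:

```python
import torch

z = torch.randn(4, 8, requires_grad=True)  # continuous encoder outputs (toy shapes)
codebook = torch.randn(16, 8)

# Squared distances to every code, then hard assignment to the nearest one.
dists = z.pow(2).sum(1, keepdim=True) - 2 * z @ codebook.t() + codebook.pow(2).sum(1)
quantized = codebook[dists.argmin(dim=1)]

# Straight-through: forward value is the code, gradient flows to z as if quantization were identity.
z_q = z + (quantized - z).detach()
z_q.sum().backward()  # z.grad is populated; the argmin itself is never differentiated
```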
@ -1,305 +0,0 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
class BaseTacotron(BaseTTS):
|
||||
"""Base class shared by Tacotron and Tacotron2"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: "TacotronConfig",
|
||||
ap: "AudioProcessor",
|
||||
tokenizer: "TTSTokenizer",
|
||||
speaker_manager: SpeakerManager = None,
|
||||
):
|
||||
super().__init__(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
# pass all config fields as class attributes
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# init tensors
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# Capacitron
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
|
||||
self.capacitron_vae_layer = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
"""Set missing fields to their default values"""
|
||||
if aux_input:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
return None
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
"""Init the backward decoder for Forward-Backward decoding."""
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
"""Init the coarse decoder for Double-Decoder Consistency."""
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
"""Load model checkpoint and set up internals.
|
||||
|
||||
Args:
|
||||
config (Coqpit): model configuration.
|
||||
checkpoint_path (str): path to checkpoint file.
|
||||
eval (bool, optional): whether to load model for evaluation.
|
||||
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
|
||||
"""
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||
self.load_state_dict(state["model"])
|
||||
# TODO: set r in run-time by taking it from the new config
|
||||
if "r" in state:
|
||||
# set r from the state (for compatibility with older checkpoints)
|
||||
self.decoder.set_r(state["r"])
|
||||
elif "config" in state:
|
||||
# set r from config used at training time (for inference)
|
||||
self.decoder.set_r(state["config"]["r"])
|
||||
else:
|
||||
# set r from the new config (for new-models)
|
||||
self.decoder.set_r(config.r)
|
||||
if eval:
|
||||
self.eval()
|
||||
print(f" > Model's reduction rate `r` is set to: {self.decoder.r}")
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self) -> nn.Module:
|
||||
"""Get the model criterion used in training."""
|
||||
return TacotronLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def init_from_config(config: Coqpit):
|
||||
"""Initialize model from config."""
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
ap = AudioProcessor.init_from_config(config)
|
||||
tokenizer = TTSTokenizer.init_from_config(config)
|
||||
speaker_manager = SpeakerManager.init_from_config(config)
|
||||
return BaseTacotron(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
##########################
|
||||
# TEST AND LOG FUNCTIONS #
|
||||
##########################
|
||||
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Args:
|
||||
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_test_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
)
|
||||
return {"figures": test_figures, "audios": test_audios}
|
||||
|
||||
def test_log(
|
||||
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
|
||||
logger.test_figures(steps, outputs["figures"])
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
input_mask = sequence_mask(text_lengths)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len)
|
||||
return input_mask, output_mask
|
||||
|
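The max_len adjustment in compute_masks simply rounds the longest mel length up to a multiple of the reduction factor r, e.g.:

```python
max_len, r = 103, 4
padded = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
assert padded == 104 and padded % r == 0
```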
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
"""Run backwards decoder"""
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
|
||||
)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
|
||||
"""Double Decoder Consistency"""
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask
|
||||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2),
|
||||
size=alignments.shape[1],
|
||||
mode="nearest",
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
# multiply each style token with a weight
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
for k_token, v_amplifier in style_input.items():
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
# ignore style token and return zero tensor
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
else:
|
||||
# compute style tokens
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
|
||||
"""Capacitron Variational Autoencoder"""
|
||||
(
|
||||
VAE_outputs,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
) = self.capacitron_vae_layer(
|
||||
reference_mel_info,
|
||||
text_info,
|
||||
speaker_embedding, # pylint: disable=not-callable
|
||||
)
|
||||
|
||||
VAE_outputs = VAE_outputs.to(inputs.device)
|
||||
encoder_output = self._concat_speaker_embedding(
|
||||
inputs, VAE_outputs
|
||||
) # concatenate to the output of the basic tacotron encoder
|
||||
return (
|
||||
encoder_output,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + embedded_speakers_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
|
||||
return outputs
|
||||
|
||||
#############################
|
||||
# CALLBACKS
|
||||
#############################
|
||||
|
||||
def on_epoch_start(self, trainer):
|
||||
"""Callback for setting values wrt gradual training schedule.
|
||||
|
||||
Args:
|
||||
trainer (TrainerTTS): TTS trainer object that is used to train this model.
|
||||
"""
|
||||
if self.gradual_training:
|
||||
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
|
||||
trainer.config.r = r
|
||||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
|

@@ -2,12 +2,12 @@
import os
import random
from contextlib import contextmanager
from time import time

import torch
import torch.nn.functional as F
import torchaudio

from tqdm import tqdm

from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram
@@ -16,22 +16,14 @@ from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice
from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead
from TTS.tts.layers.tortoise.clvp import CLVP
from TTS.tts.layers.tortoise.cvvp import CVVP
from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps
from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts
from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter
from TTS.tts.layers.tortoise.vocoder import VocConf

from TTS.tts.layers.tortoise.diffusion import (
    SpacedDiffusion,
    get_named_beta_schedule,
    space_timesteps,
)

from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path
from TTS.tts.layers.tortoise.vocoder import VocConf
from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment

from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path

from contextlib import contextmanager

def pad_or_truncate(t, length):
    """
@@ -56,9 +48,7 @@ def load_discrete_vocoder_diffuser(
    Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
    """
    return SpacedDiffusion(
        use_timesteps=space_timesteps(
            trained_diffusion_steps, [desired_diffusion_steps]
        ),
        use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
        model_mean_type="epsilon",
        model_var_type="learned_range",
        loss_type="mse",
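
# Illustrative call sketch (assumption, not part of this diff); the argument values are examples only:
#
#   diffuser = load_discrete_vocoder_diffuser(
#       trained_diffusion_steps=4000, desired_diffusion_steps=80, cond_free=True, cond_free_k=2
#   )
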
@@ -137,12 +127,12 @@ def do_spectrogram_diffusion(
    noise = torch.randn(output_shape, device=latents.device) * temperature
    mel = diffuser.sample_loop(
        diffusion_model,
        output_shape,
        noise=noise,
        model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
        progress=verbose
    )
        diffusion_model,
        output_shape,
        noise=noise,
        model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings},
        progress=verbose,
    )
    return denormalize_tacotron_mel(mel)[:, :, :output_seq_len]
@@ -166,9 +156,7 @@ def classify_audio_clip(clip):
        kernel_size=5,
        distribute_zero_label=False,
    )
    classifier.load_state_dict(
        torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))
    )
    classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")))
    clip = clip.cpu().unsqueeze(0)
    results = F.softmax(classifier(clip), dim=-1)
    return results[0][0]
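
# Illustrative usage sketch (assumption, not part of this diff): the returned scalar is treated
# as the probability that the input came from Tortoise, e.g.:
#
#   prob_generated = classify_audio_clip(wav_tensor)  # `wav_tensor` is a mono waveform tensor
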
@@ -238,9 +226,7 @@ class TextToSpeech:
        self.diff_checkpoint = diff_checkpoint  # TODO: check if this is even needed
        self.models_dir = models_dir
        self.autoregressive_batch_size = (
            pick_best_batch_size_for_gpu()
            if autoregressive_batch_size is None
            else autoregressive_batch_size
            pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size
        )
        self.enable_redaction = enable_redaction
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -274,9 +260,7 @@ class TextToSpeech:
        self.autoregressive.load_state_dict(torch.load(ar_path))
        self.autoregressive.post_init_gpt2_config(kv_cache)

        diff_path = diff_checkpoint or get_model_path(
            "diffusion_decoder.pth", models_dir
        )
        diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir)
        self.diffusion = (
            DiffusionTts(
                model_channels=1024,
@@ -365,9 +349,7 @@ class TextToSpeech:
            .cpu()
            .eval()
        )
        self.cvvp.load_state_dict(
            torch.load(get_model_path("cvvp.pth", self.models_dir))
        )
        self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir)))

    def get_conditioning_latents(
        self,
@@ -407,11 +389,7 @@ class TextToSpeech:
        DURS_CONST = 102400
        for ls in voice_samples:
            # The diffuser operates at a sample rate of 24000 (except for the latent inputs)
            sample = (
                torchaudio.functional.resample(ls[0], 22050, 24000)
                if original_tortoise
                else ls[1]
            )
            sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1]
            if latent_averaging_mode == 0:
                sample = pad_or_truncate(sample, DURS_CONST)
                cond_mel = wav_to_univnet_mel(
@@ -426,9 +404,7 @@ class TextToSpeech:
            if latent_averaging_mode == 2:
                temp_diffusion_conds = []
            for chunk in range(ceil(sample.shape[1] / DURS_CONST)):
                current_sample = sample[
                    :, chunk * DURS_CONST : (chunk + 1) * DURS_CONST
                ]
                current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST]
                current_sample = pad_or_truncate(current_sample, DURS_CONST)
                cond_mel = wav_to_univnet_mel(
                    current_sample.to(self.device),
@@ -440,9 +416,7 @@ class TextToSpeech:
                elif latent_averaging_mode == 2:
                    temp_diffusion_conds.append(cond_mel)
            if latent_averaging_mode == 2:
                diffusion_conds.append(
                    torch.stack(temp_diffusion_conds).mean(0)
                )
                diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0))
        diffusion_conds = torch.stack(diffusion_conds, dim=1)

        with self.temporary_cuda(self.diffusion) as diffusion:
@@ -471,9 +445,7 @@ class TextToSpeech:
                )
            )
            with torch.no_grad():
                return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(
                    torch.tensor([0.0])
                )
                return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0]))

    def tts_with_preset(self, text, preset="fast", **kwargs):
        """
@@ -521,10 +493,7 @@ class TextToSpeech:
                "diffusion_iterations": 50,
                "sampler": "ddim",
            },
            "fast_old": {
                "num_autoregressive_samples": 96,
                "diffusion_iterations": 80
            },
            "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80},
            "standard": {
                "num_autoregressive_samples": 256,
                "diffusion_iterations": 200,
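
# Illustrative usage sketch (assumption, not part of this diff): presets trade speed for quality
# by scaling `num_autoregressive_samples` and `diffusion_iterations`, e.g.:
#
#   tts = TextToSpeech()
#   wav = tts.tts_with_preset("Hello world.", voice_samples=voice_samples, preset="fast")
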
@@ -618,9 +587,7 @@ class TextToSpeech:
        """
        deterministic_seed = self.deterministic_state(seed=use_deterministic_seed)

        text_tokens = (
            torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
        )
        text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device)
        text_tokens = F.pad(text_tokens, (0, 1))  # This may not be necessary.
        assert (
            text_tokens.shape[-1] < 400
@@ -628,12 +595,7 @@ class TextToSpeech:
        auto_conds = None
        if voice_samples is not None:
            (
                auto_conditioning,
                diffusion_conditioning,
                auto_conds,
                _,
            ) = self.get_conditioning_latents(
            (auto_conditioning, diffusion_conditioning, auto_conds, _,) = self.get_conditioning_latents(
                voice_samples,
                return_mels=True,
                latent_averaging_mode=latent_averaging_mode,
@@ -650,10 +612,7 @@ class TextToSpeech:
            diffusion_conditioning = diffusion_conditioning.to(self.device)

        diffuser = load_discrete_vocoder_diffuser(
            desired_diffusion_steps=diffusion_iterations,
            cond_free=cond_free,
            cond_free_k=cond_free_k,
            sampler=sampler
            desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler
        )

        # in the case of single_sample,
@@ -664,13 +623,13 @@ class TextToSpeech:
        samples = []
        num_batches = num_autoregressive_samples // self.autoregressive_batch_size
        stop_mel_token = self.autoregressive.stop_mel_token
        calm_token = 83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
        calm_token = (
            83  # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output"
        )
        self.autoregressive = self.autoregressive.to(self.device)
        if verbose:
            print("Generating autoregressive samples..")
        with self.temporary_cuda(
            self.autoregressive
        ) as autoregressive, torch.autocast(
        with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=half
        ):
            for b in tqdm(range(num_batches), disable=not verbose):
@@ -689,9 +648,7 @@ class TextToSpeech:
                    padding_needed = max_mel_tokens - codes.shape[1]
                    codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
                    samples.append(codes)
            self.autoregressive_batch_size = (
                orig_batch_size  # in the case of single_sample
            )
            self.autoregressive_batch_size = orig_batch_size  # in the case of single_sample

        clip_results = []
        with self.temporary_cuda(self.clvp) as clvp, torch.autocast(
@@ -729,9 +686,7 @@ class TextToSpeech:
                    if cvvp_amount == 1:
                        clip_results.append(cvvp)
                    else:
                        clip_results.append(
                            cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)
                        )
                        clip_results.append(cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount))
                else:
                    clip_results.append(clvp_res)
            clip_results = torch.cat(clip_results, dim=0)
@@ -744,19 +699,14 @@ class TextToSpeech:
        # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning
        # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
        # results, but will increase memory usage.
        with self.temporary_cuda(
            self.autoregressive
        ) as autoregressive:
        with self.temporary_cuda(self.autoregressive) as autoregressive:
            best_latents = autoregressive(
                auto_conditioning.repeat(k, 1),
                text_tokens.repeat(k, 1),
                torch.tensor([text_tokens.shape[-1]], device=text_tokens.device),
                best_results,
                torch.tensor(
                    [
                        best_results.shape[-1]
                        * self.autoregressive.mel_length_compression
                    ],
                    [best_results.shape[-1] * self.autoregressive.mel_length_compression],
                    device=text_tokens.device,
                ),
                return_latent=True,
@@ -778,9 +728,7 @@ class TextToSpeech:
                    ctokens += 1
                else:
                    ctokens = 0
                if (
                    ctokens > 8
                ):  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                if ctokens > 8:  # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
                    latents = latents[:, :k]
                    break
            with self.temporary_cuda(self.diffusion) as diffusion:
@@ -801,10 +749,7 @@ class TextToSpeech:
                return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1)
            return clip

        wav_candidates = [
            potentially_redact(wav_candidate, text)
            for wav_candidate in wav_candidates
        ]
        wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]

        if len(wav_candidates) > 1:
            res = wav_candidates
@@ -97,7 +97,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
        self.mel_norm = mel_norm
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        self.normalized=normalized
        self.normalized = normalized
        if use_mel:
            self._build_mel_basis()
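
# Illustrative sketch (assumption; constructor arguments beyond those visible above are guesses):
#
#   stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024, use_mel=True, normalized=False)
#   spec = stft(wav)  # hypothetical call; the actual signature may differ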