diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index d107b474..d78011f7 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -1,9 +1,8 @@ -import os import functools import math +import os import torch - import torch.nn as nn import torch.nn.functional as F import torchaudio @@ -11,6 +10,7 @@ from transformers import LogitsWarper from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias + def zero_module(module): """ Zero out the parameters of a module and return it. @@ -64,11 +64,11 @@ class QKVAttentionLegacy(nn.Module): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = torch.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards if rel_pos is not None: - weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1]) + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1] + ) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) if mask is not None: # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. @@ -112,7 +112,13 @@ class AttentionBlock(nn.Module): self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) if relative_pos_embeddings: - self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) + self.relative_pos_embeddings = RelativePositionBias( + scale=(channels // self.num_heads) ** 0.5, + causal=False, + heads=num_heads, + num_buckets=32, + max_distance=64, + ) else: self.relative_pos_embeddings = None @@ -168,9 +174,7 @@ class Downsample(nn.Module): stride = factor if use_conv: - self.op = nn.Conv1d( - self.channels, self.out_channels, ksize, stride=stride, padding=pad - ) + self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad) else: assert self.channels == self.out_channels self.op = nn.AvgPool1d(kernel_size=stride, stride=stride) @@ -182,15 +186,15 @@ class Downsample(nn.Module): class ResBlock(nn.Module): def __init__( - self, - channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - up=False, - down=False, - kernel_size=3, + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + up=False, + down=False, + kernel_size=3, ): super().__init__() self.channels = channels @@ -221,17 +225,13 @@ class ResBlock(nn.Module): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding) - ), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = nn.Conv1d( - channels, self.out_channels, kernel_size, padding=padding - ) + self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding) else: self.skip_connection = nn.Conv1d(channels, self.out_channels, 1) @@ 
-249,37 +249,38 @@ class ResBlock(nn.Module): class AudioMiniEncoder(nn.Module): - def __init__(self, - spec_dim, - embedding_dim, - base_channels=128, - depth=2, - resnet_blocks=2, - attn_blocks=4, - num_attn_heads=4, - dropout=0, - downsample_factor=2, - kernel_size=3): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): super().__init__() - self.init = nn.Sequential( - nn.Conv1d(spec_dim, base_channels, 3, padding=1) - ) + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) ch = base_channels res = [] for l in range(depth): for r in range(resnet_blocks): res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) - res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) ch *= 2 self.res = nn.Sequential(*res) - self.final = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.Conv1d(ch, embedding_dim, 1) - ) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads,)) + attn.append( + AttentionBlock( + embedding_dim, + num_attn_heads, + ) + ) self.attn = nn.Sequential(*attn) self.dim = embedding_dim @@ -291,15 +292,24 @@ class AudioMiniEncoder(nn.Module): return h[:, :, 0] -DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../utils/assets/tortoise/mel_norms.pth') +DEFAULT_MEL_NORM_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/mel_norms.pth" +) class TorchMelSpectrogram(nn.Module): - def __init__(self, filter_length=1024, hop_length=256, - win_length=1024, n_mel_channels=80, - mel_fmin=0, mel_fmax=8000, - sampling_rate=22050, normalize=False, - mel_norm_file=DEFAULT_MEL_NORM_FILE): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=22050, + normalize=False, + mel_norm_file=DEFAULT_MEL_NORM_FILE, + ): super().__init__() # These are the default tacotron values for the MEL spectrogram. 
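# Aside (not part of the diff): a minimal sketch of the torchaudio transform that
# TorchMelSpectrogram builds a few lines below, using the same default parameters as
# the __init__ signature above; the mel_norms.pth normalization step is omitted here.
import torch
import torchaudio

mel_stft = torchaudio.transforms.MelSpectrogram(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
    norm="slaney",
)
wav = torch.randn(1, 22050)  # one second of (fake) mono audio at 22.05 kHz
mel = mel_stft(wav)          # -> [1, 80, 87]: 80 mel bands, 1 + 22050 // 256 centered frames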
self.filter_length = filter_length @@ -309,16 +319,18 @@ class TorchMelSpectrogram(nn.Module): self.mel_fmin = mel_fmin self.mel_fmax = mel_fmax self.sampling_rate = sampling_rate - self.mel_stft = torchaudio.transforms.MelSpectrogram(n_fft=self.filter_length, - hop_length=self.hop_length, - win_length=self.win_length, - power=2, - normalized=normalize, - sample_rate=self.sampling_rate, - f_min=self.mel_fmin, - f_max=self.mel_fmax, - n_mels=self.n_mel_channels, - norm="slaney") + self.mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + power=2, + normalized=normalize, + sample_rate=self.sampling_rate, + f_min=self.mel_fmin, + f_max=self.mel_fmax, + n_mels=self.n_mel_channels, + norm="slaney", + ) self.mel_norm_file = mel_norm_file if self.mel_norm_file is not None: self.mel_norms = torch.load(self.mel_norm_file) @@ -326,7 +338,9 @@ class TorchMelSpectrogram(nn.Module): self.mel_norms = None def forward(self, inp): - if len(inp.shape) == 3: # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + if ( + len(inp.shape) == 3 + ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) inp = inp.squeeze(1) assert len(inp.shape) == 2 self.mel_stft = self.mel_stft.to(inp.device) @@ -344,6 +358,7 @@ class CheckpointedLayer(nn.Module): Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses checkpoint for all other args. """ + def __init__(self, wrap): super().__init__() self.wrap = wrap @@ -360,6 +375,7 @@ class CheckpointedXTransformerEncoder(nn.Module): Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid to channels-last that XTransformer expects. 
""" + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): super().__init__() self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) @@ -374,10 +390,10 @@ class CheckpointedXTransformerEncoder(nn.Module): def forward(self, x, **kwargs): if self.needs_permute: - x = x.permute(0,2,1) + x = x.permute(0, 2, 1) h = self.transformer(x, **kwargs) if self.exit_permute: - h = h.permute(0,2,1) + h = h.permute(0, 2, 1) return h @@ -392,9 +408,7 @@ class TypicalLogitsWarper(LogitsWarper): self.mass = mass self.min_tokens_to_keep = min_tokens_to_keep - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # calculate entropy normalized = torch.nn.functional.log_softmax(scores, dim=-1) p = torch.exp(normalized) @@ -409,15 +423,11 @@ class TypicalLogitsWarper(LogitsWarper): # Remove tokens with cumulative mass above the threshold last_ind = (cumulative_probs < self.mass).sum(dim=1) last_ind[last_ind < 0] = 0 - sorted_indices_to_remove = sorted_scores > sorted_scores.gather( - 1, last_ind.view(-1, 1) - ) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) - return scores \ No newline at end of file + return scores diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index b09b996f..07a3ff0d 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -7,11 +7,10 @@ import numpy as np import torch import torchaudio from scipy.io.wavfile import read + from TTS.utils.audio.torch_transforms import TorchSTFT -BUILTIN_VOICES_DIR = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices" -) +BUILTIN_VOICES_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/voices") def load_wav_to_torch(full_path): @@ -58,10 +57,7 @@ def read_audio_file(audiopath: str): def load_required_audio(audiopath: str): audio, lsr = read_audio_file(audiopath) - audios = [ - torchaudio.functional.resample(audio, lsr, sampling_rate) - for sampling_rate in (22050, 24000) - ] + audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)] for audio in audios: check_audio(audio, audiopath) @@ -83,9 +79,7 @@ TACOTRON_MEL_MIN = -11.512925148010254 def denormalize_tacotron_mel(norm_mel): - return ((norm_mel + 1) / 2) * ( - TACOTRON_MEL_MAX - TACOTRON_MEL_MIN - ) + TACOTRON_MEL_MIN + return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN def normalize_tacotron_mel(mel): @@ -118,11 +112,7 @@ def get_voices(extra_voice_dirs: List[str] = []): for sub in subs: subj = os.path.join(d, sub) if os.path.isdir(subj): - voices[sub] = ( - list(glob(f"{subj}/*.wav")) - + list(glob(f"{subj}/*.mp3")) - + list(glob(f"{subj}/*.pth")) - ) + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + 
list(glob(f"{subj}/*.pth")) return voices @@ -148,9 +138,7 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): for voice in voices: if voice == "random": if len(voices) > 1: - print( - "Cannot combine a random voice with a non-random voice. Just using a random voice." - ) + print("Cannot combine a random voice with a non-random voice. Just using a random voice.") return None, None clip, latent = load_voice(voice, extra_voice_dirs) if latent is None: @@ -171,18 +159,21 @@ def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): latents = (latents_0, latents_1) return None, latents + def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"): - stft = TorchSTFT(n_fft=1024, - hop_length=256, - win_length=1024, - use_mel=True, - n_mels=100, - sample_rate=24000, - mel_fmin=0, - mel_fmax=12000) + stft = TorchSTFT( + n_fft=1024, + hop_length=256, + win_length=1024, + use_mel=True, + n_mels=100, + sample_rate=24000, + mel_fmin=0, + mel_fmax=12000, + ) stft = stft.to(device) mel = stft(wav) mel = dynamic_range_compression(mel) if do_normalization: mel = normalize_tacotron_mel(mel) - return mel \ No newline at end of file + return mel diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 23c05f92..14d881bc 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -9,6 +9,7 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper + def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) @@ -98,9 +99,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): assert self.cached_mel_emb is not None assert inputs_embeds is None # Not supported by this inference model. assert labels is None # Training not supported by this inference model. - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Create embedding mel_len = self.cached_mel_emb.shape[1] @@ -109,9 +108,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): text_emb = self.embeddings(text_inputs) text_emb = text_emb + self.text_pos_embedding(text_emb) if self.cached_mel_emb.shape[0] != text_emb.shape[0]: - mel_emb = self.cached_mel_emb.repeat_interleave( - text_emb.shape[0] // self.cached_mel_emb.shape[0], 0 - ) + mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0) else: # this outcome only occurs once per loop in most cases mel_emb = self.cached_mel_emb emb = torch.cat([mel_emb, text_emb], dim=1) @@ -158,10 +155,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel): called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
""" return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) for layer_past in past ) @@ -210,9 +204,7 @@ class LearnedPositionEmbeddings(nn.Module): return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind] -def build_hf_gpt_transformer( - layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing -): +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): """ GPT-2 implemented by the HuggingFace library. """ @@ -230,9 +222,7 @@ def build_hf_gpt_transformer( ) gpt = GPT2Model(gpt_config) # Override the built in positional embeddings - del ( - gpt.wpe - ) # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024) + del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024) gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Built-in token embeddings are unused. del gpt.wte @@ -251,21 +241,15 @@ class MelEncoder(nn.Module): self.channels = channels self.encoder = nn.Sequential( nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), - nn.Sequential( - *[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)] - ), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels // 16, channels // 2), nn.ReLU(), - nn.Sequential( - *[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)] - ), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), nn.GroupNorm(channels // 8, channels), nn.ReLU(), - nn.Sequential( - *[ResBlock(channels) for _ in range(resblocks_per_reduction)] - ), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), ) self.reduction = 4 @@ -317,9 +301,7 @@ class UnifiedVoice(nn.Module): super().__init__() self.number_text_tokens = number_text_tokens - self.start_text_token = ( - number_text_tokens * types if start_text_token is None else start_text_token - ) + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token self.stop_text_token = 0 self.number_mel_codes = number_mel_codes self.start_mel_token = start_mel_token @@ -331,12 +313,8 @@ class UnifiedVoice(nn.Module): self.model_dim = model_dim self.max_conditioning_inputs = max_conditioning_inputs self.mel_length_compression = mel_length_compression - self.conditioning_encoder = ConditioningEncoder( - 80, model_dim, num_attn_heads=heads - ) - self.text_embedding = nn.Embedding( - self.number_text_tokens * types + 1, model_dim - ) + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) if use_mel_codes_as_input: self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) else: @@ -356,12 +334,8 @@ class UnifiedVoice(nn.Module): checkpointing, ) if train_solo_embeddings: - self.mel_solo_embedding = nn.Parameter( - torch.randn(1, 1, model_dim) * 0.02, requires_grad=True - ) - self.text_solo_embedding = nn.Parameter( - torch.randn(1, 1, model_dim) * 0.02, requires_grad=True - ) + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + 
self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) else: self.mel_solo_embedding = 0 self.text_solo_embedding = 0 @@ -414,9 +388,7 @@ class UnifiedVoice(nn.Module): preformatting to create a working TTS model. """ # Set padding areas within MEL (currently it is coded with the MEL code for ). - mel_lengths = torch.div( - wav_lengths, self.mel_length_compression, rounding_mode="trunc" - ) + mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc") for b in range(len(mel_lengths)): actual_end = ( mel_lengths[b] + 1 @@ -436,31 +408,22 @@ class UnifiedVoice(nn.Module): return_latent=False, ): if second_inputs is not None: - emb = torch.cat( - [speech_conditioning_inputs, first_inputs, second_inputs], dim=1 - ) + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) else: emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) - gpt_out = self.gpt( - inputs_embeds=emb, return_dict=True, output_attentions=get_attns - ) + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) if get_attns: return gpt_out.attentions - enc = gpt_out.last_hidden_state[ - :, 1: - ] # The first logit is tied to the speech_conditioning_input + enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input enc = self.final_norm(enc) if return_latent: return ( enc[ :, - speech_conditioning_inputs.shape[ - 1 - ] : speech_conditioning_inputs.shape[1] - + first_inputs.shape[1], + speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1], ], enc[:, -second_inputs.shape[1] :], ) @@ -539,9 +502,7 @@ class UnifiedVoice(nn.Module): text_inputs, text_targets = self.build_aligned_inputs_and_targets( text_inputs, self.start_text_token, self.stop_text_token ) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding( - text_inputs - ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) mel_codes, mel_targets = self.build_aligned_inputs_and_targets( mel_codes, self.start_mel_token, self.stop_mel_token ) @@ -596,15 +557,13 @@ class UnifiedVoice(nn.Module): max_generate_length=None, typical_sampling=False, typical_mass=0.9, - **hf_generate_kwargs + **hf_generate_kwargs, ): text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) text_inputs, text_targets = self.build_aligned_inputs_and_targets( text_inputs, self.start_text_token, self.stop_text_token ) - text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding( - text_inputs - ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) conds = speech_conditioning_latent.unsqueeze(1) emb = torch.cat([conds, text_emb], dim=1) @@ -628,20 +587,14 @@ class UnifiedVoice(nn.Module): num_return_sequences % input_tokens.shape[0] == 0 ), "The number of return sequences must be divisible by the number of input sequences" fake_inputs = fake_inputs.repeat(num_return_sequences, 1) - input_tokens = input_tokens.repeat( - num_return_sequences // input_tokens.shape[0], 1 - ) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) inputs = torch.cat([fake_inputs, input_tokens], dim=1) logits_processor = ( - LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) - if typical_sampling - else LogitsProcessorList() + LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() ) # TODO disable this max_length 
= ( - trunc_index + self.max_mel_tokens - 1 - if max_generate_length is None - else trunc_index + max_generate_length + trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length ) gen = self.inference_model.generate( inputs, @@ -651,7 +604,7 @@ class UnifiedVoice(nn.Module): max_length=max_length, logits_processor=logits_processor, num_return_sequences=num_return_sequences, - **hf_generate_kwargs + **hf_generate_kwargs, ) return gen[:, trunc_index:] diff --git a/TTS/tts/layers/tortoise/classifier.py b/TTS/tts/layers/tortoise/classifier.py index 0f48a550..8764bb07 100644 --- a/TTS/tts/layers/tortoise/classifier.py +++ b/TTS/tts/layers/tortoise/classifier.py @@ -1,13 +1,7 @@ import torch import torch.nn as nn -from TTS.tts.layers.tortoise.arch_utils import ( - AttentionBlock, - Downsample, - Upsample, - normalization, - zero_module, -) +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module class ResBlock(nn.Module): @@ -54,19 +48,13 @@ class ResBlock(nn.Module): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - nn.Conv1d( - self.out_channels, self.out_channels, kernel_size, padding=padding - ) - ), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = nn.Conv1d( - dims, channels, self.out_channels, kernel_size, padding=padding - ) + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding) else: self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1) @@ -104,24 +92,14 @@ class AudioMiniEncoder(nn.Module): self.layers = depth for l in range(depth): for r in range(resnet_blocks): - res.append( - ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size) - ) - res.append( - Downsample( - ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor - ) - ) + res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) ch *= 2 self.res = nn.Sequential(*res) - self.final = nn.Sequential( - normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1) - ) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) attn = [] for a in range(attn_blocks): - attn.append( - AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False) - ) + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py index aeffb8eb..69b8c17c 100644 --- a/TTS/tts/layers/tortoise/clvp.py +++ b/TTS/tts/layers/tortoise/clvp.py @@ -12,9 +12,10 @@ def exists(val): return val is not None -def masked_mean(t, mask, dim = 1): - t = t.masked_fill(~mask[:, :, None], 0.) 
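# Aside (not part of the diff): usage sketch for the masked_mean helper being
# reformatted here -- it averages over the sequence dimension while ignoring masked
# (padded) positions.
import torch

def masked_mean(t, mask, dim=1):
    t = t.masked_fill(~mask[:, :, None], 0.0)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]

t = torch.randn(2, 4, 8)                    # [batch, seq, dim]
mask = torch.tensor([[1, 1, 1, 0],          # last position of item 0 is padding
                     [1, 1, 0, 0]]).bool()  # last two positions of item 1 are padding
pooled = masked_mean(t, mask)               # [2, 8]; padded positions contribute nothing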
- return t.sum(dim = 1) / mask.sum(dim = 1)[..., None] +def masked_mean(t, mask, dim=1): + t = t.masked_fill(~mask[:, :, None], 0.0) + return t.sum(dim=1) / mask.sum(dim=1)[..., None] + class CLVP(nn.Module): """ @@ -25,23 +26,23 @@ class CLVP(nn.Module): """ def __init__( - self, - *, - dim_text=512, - dim_speech=512, - dim_latent=512, - num_text_tokens=256, - text_enc_depth=6, - text_seq_len=120, - text_heads=8, - num_speech_tokens=8192, - speech_enc_depth=6, - speech_heads=8, - speech_seq_len=250, - text_mask_percentage=0, - voice_mask_percentage=0, - wav_token_compression=1024, - use_xformers=False, + self, + *, + dim_text=512, + dim_speech=512, + dim_latent=512, + num_text_tokens=256, + text_enc_depth=6, + text_seq_len=120, + text_heads=8, + num_speech_tokens=8192, + speech_enc_depth=6, + speech_heads=8, + speech_seq_len=250, + text_mask_percentage=0, + voice_mask_percentage=0, + wav_token_compression=1024, + use_xformers=False, ): super().__init__() self.text_emb = nn.Embedding(num_text_tokens, dim_text) @@ -59,13 +60,14 @@ class CLVP(nn.Module): dim=dim_text, depth=text_enc_depth, heads=text_heads, - ff_dropout=.1, + ff_dropout=0.1, ff_mult=2, - attn_dropout=.1, + attn_dropout=0.1, use_rmsnorm=True, ff_glu=True, rotary_pos_emb=True, - )) + ), + ) self.speech_transformer = CheckpointedXTransformerEncoder( needs_permute=False, exit_permute=False, @@ -74,20 +76,23 @@ class CLVP(nn.Module): dim=dim_speech, depth=speech_enc_depth, heads=speech_heads, - ff_dropout=.1, + ff_dropout=0.1, ff_mult=2, - attn_dropout=.1, + attn_dropout=0.1, use_rmsnorm=True, ff_glu=True, rotary_pos_emb=True, - )) + ), + ) else: - self.text_transformer = Transformer(causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, - heads=text_heads) - self.speech_transformer = Transformer(causal=False, seq_len=speech_seq_len, dim=dim_speech, - depth=speech_enc_depth, heads=speech_heads) + self.text_transformer = Transformer( + causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads + ) + self.speech_transformer = Transformer( + causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads + ) - self.temperature = nn.Parameter(torch.tensor(1.)) + self.temperature = nn.Parameter(torch.tensor(1.0)) self.text_mask_percentage = text_mask_percentage self.voice_mask_percentage = voice_mask_percentage self.wav_token_compression = wav_token_compression @@ -96,12 +101,7 @@ class CLVP(nn.Module): self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) - def forward( - self, - text, - speech_tokens, - return_loss=False - ): + def forward(self, text, speech_tokens, return_loss=False): b, device = text.shape[0], text.device if self.training: text_mask = torch.rand_like(text.float()) > self.text_mask_percentage @@ -131,25 +131,29 @@ class CLVP(nn.Module): temp = self.temperature.exp() if not return_loss: - sim = einsum('n d, n d -> n', text_latents, speech_latents) * temp + sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp return sim - sim = einsum('i d, j d -> i j', text_latents, speech_latents) * temp + sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp labels = torch.arange(b, device=device) loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 return loss -if __name__ == '__main__': - clip = CLVP(text_mask_percentage=.2, voice_mask_percentage=.2) - clip(torch.randint(0,256,(2,120)), - torch.tensor([50,100]), - 
torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), - return_loss=True) - nonloss = clip(torch.randint(0,256,(2,120)), - torch.tensor([50,100]), - torch.randint(0,8192,(2,250)), - torch.tensor([101,102]), - return_loss=False) - print(nonloss.shape) \ No newline at end of file +if __name__ == "__main__": + clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2) + clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=True, + ) + nonloss = clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=False, + ) + print(nonloss.shape) diff --git a/TTS/tts/layers/tortoise/cvvp.py b/TTS/tts/layers/tortoise/cvvp.py index 171becaa..215ba3ac 100644 --- a/TTS/tts/layers/tortoise/cvvp.py +++ b/TTS/tts/layers/tortoise/cvvp.py @@ -17,16 +17,7 @@ def masked_mean(t, mask): class CollapsingTransformer(nn.Module): - def __init__( - self, - model_dim, - output_dims, - heads, - dropout, - depth, - mask_percentage=0, - **encoder_kwargs - ): + def __init__(self, model_dim, output_dims, heads, dropout, depth, mask_percentage=0, **encoder_kwargs): super().__init__() self.transformer = ContinuousTransformerWrapper( max_seq_len=-1, @@ -105,9 +96,7 @@ class CVVP(nn.Module): self.to_conditioning_latent = nn.Linear(latent_dim, latent_dim, bias=False) if mel_codes is None: - self.speech_emb = nn.Conv1d( - mel_channels, model_dim, kernel_size=5, padding=2 - ) + self.speech_emb = nn.Conv1d(mel_channels, model_dim, kernel_size=5, padding=2) else: self.speech_emb = ConvFormatEmbedding(mel_codes, model_dim) self.speech_transformer = CollapsingTransformer( @@ -135,9 +124,7 @@ class CVVP(nn.Module): enc_speech = self.speech_transformer(speech_emb) speech_latents = self.to_speech_latent(enc_speech) - cond_latents, speech_latents = map( - lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents) - ) + cond_latents, speech_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (cond_latents, speech_latents)) temp = self.temperature.exp() if not return_loss: diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index 8dc63ae5..eb9e90df 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -13,8 +13,8 @@ import math import numpy as np import torch import torch as th -from tqdm import tqdm from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral +from tqdm import tqdm from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper @@ -38,18 +38,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for th.exp(). 
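# Aside (not part of the diff): a quick numerical check of the diagonal-Gaussian KL
# identity that normal_kl implements just below,
#   KL(N(m1, v1) || N(m2, v2)) = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
#                                       + (m1 - m2)**2 * exp(-logvar2) - 1),
# verified against torch.distributions.
import torch
from torch.distributions import Normal, kl_divergence

m1, lv1 = torch.tensor(0.3), torch.tensor(-1.0)
m2, lv2 = torch.tensor(-0.2), torch.tensor(0.5)
manual = 0.5 * (-1.0 + lv2 - lv1 + torch.exp(lv1 - lv2) + (m1 - m2) ** 2 * torch.exp(-lv2))
reference = kl_divergence(Normal(m1, (0.5 * lv1).exp()), Normal(m2, (0.5 * lv2).exp()))
assert torch.allclose(manual, reference)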
- logvar1, logvar2 = [ - x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + th.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * th.exp(-logvar2) - ) + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) def approx_standard_normal_cdf(x): @@ -112,9 +103,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): scale = 1000 / num_diffusion_timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 - return np.linspace( - beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 - ) + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) elif schedule_name == "cosine": return betas_for_alpha_bar( num_diffusion_timesteps, @@ -149,9 +138,9 @@ class ModelMeanType(enum.Enum): Which type of output the model predicts. """ - PREVIOUS_X = 'previous_x' # the model predicts x_{t-1} - START_X = 'start_x' # the model predicts x_0 - EPSILON = 'epsilon' # the model predicts epsilon + PREVIOUS_X = "previous_x" # the model predicts x_{t-1} + START_X = "start_x" # the model predicts x_0 + EPSILON = "epsilon" # the model predicts epsilon class ModelVarType(enum.Enum): @@ -162,17 +151,17 @@ class ModelVarType(enum.Enum): values between FIXED_SMALL and FIXED_LARGE, making its job easier. """ - LEARNED = 'learned' - FIXED_SMALL = 'fixed_small' - FIXED_LARGE = 'fixed_large' - LEARNED_RANGE = 'learned_range' + LEARNED = "learned" + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED_RANGE = "learned_range" class LossType(enum.Enum): - MSE = 'mse' # use raw MSE loss (and KL when learning variances) - RESCALED_MSE = 'rescaled_mse' # use raw MSE loss (with RESCALED_KL when learning variances) - KL = 'kl' # use the variational lower-bound - RESCALED_KL = 'rescaled_kl' # like KL, but rescale to estimate the full VLB + MSE = "mse" # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) + KL = "kl" # use the variational lower-bound + RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB def is_vb(self): return self == LossType.KL or self == LossType.RESCALED_KL @@ -239,22 +228,12 @@ class GaussianDiffusion: self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. 
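# Aside (not part of the diff): how the schedule quantities reformatted in this hunk
# fit together for the "linear" schedule, in standard DDPM notation.
import numpy as np

num_diffusion_timesteps = 1000
scale = 1000 / num_diffusion_timesteps
betas = np.linspace(scale * 0.0001, scale * 0.02, num_diffusion_timesteps, dtype=np.float64)

alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)                  # ~1 at t=0, ~0 at t=T-1
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
# Forward process used by q_sample below:
#   q(x_t | x_0) = N(sqrt_alphas_cumprod[t] * x_0, (1 - alphas_cumprod[t]) * I)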
- self.posterior_log_variance_clipped = np.log( - np.append(self.posterior_variance[1], self.posterior_variance[1:]) - ) - self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - ) - self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) - * np.sqrt(alphas) - / (1.0 - self.alphas_cumprod) - ) + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) def q_mean_variance(self, x_start, t): """ @@ -264,13 +243,9 @@ class GaussianDiffusion: :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ - mean = ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - ) + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = _extract_into_tensor( - self.log_one_minus_alphas_cumprod, t, x_start.shape - ) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def q_sample(self, x_start, t, noise=None): @@ -289,8 +264,7 @@ class GaussianDiffusion: assert noise.shape == x_start.shape return ( _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): @@ -306,9 +280,7 @@ class GaussianDiffusion: + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor( - self.posterior_log_variance_clipped, t, x_t.shape - ) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) assert ( posterior_mean.shape[0] == posterior_variance.shape[0] @@ -317,9 +289,7 @@ class GaussianDiffusion: ) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance( - self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None - ): + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0. @@ -358,9 +328,7 @@ class GaussianDiffusion: model_log_variance = model_var_values model_variance = th.exp(model_log_variance) else: - min_log = _extract_into_tensor( - self.posterior_log_variance_clipped, t, x.shape - ) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. 
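# Aside (not part of the diff): the LEARNED_RANGE branch around this comment maps the
# network output model_var_values from [-1, 1] onto a log-variance between the clipped
# posterior variance (min_log) and log(beta_t) (max_log); the blending line itself sits
# just outside this hunk in the original file.
import torch

def interp_log_variance(model_var_values, min_log, max_log):
    frac = (model_var_values + 1) / 2          # [-1, 1] -> [0, 1]
    return frac * max_log + (1 - frac) * min_log

min_log = torch.log(torch.tensor(1e-4))        # stand-in for posterior_log_variance_clipped[t]
max_log = torch.log(torch.tensor(2e-2))        # stand-in for log(betas[t])
var_mid = interp_log_variance(torch.tensor(0.0), min_log, max_log).exp()
# var_mid == sqrt(1e-4 * 2e-2): the geometric mean of the two variance bounds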
frac = (model_var_values + 1) / 2 @@ -398,26 +366,18 @@ class GaussianDiffusion: return x if self.model_mean_type == ModelMeanType.PREVIOUS_X: - pred_xstart = process_xstart( - self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) - ) + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) model_mean = model_output elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: if self.model_mean_type == ModelMeanType.START_X: pred_xstart = process_xstart(model_output) else: - pred_xstart = process_xstart( - self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) - ) - model_mean, _, _ = self.q_posterior_mean_variance( - x_start=pred_xstart, x_t=x, t=t - ) + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) else: raise NotImplementedError(self.model_mean_type) - assert ( - model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - ) + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape return { "mean": model_mean, "variance": model_variance, @@ -436,16 +396,12 @@ class GaussianDiffusion: assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor( - self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape - ) - * x_t + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): @@ -463,9 +419,7 @@ class GaussianDiffusion: This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
""" gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) - new_mean = ( - p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - ) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() return new_mean def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): @@ -481,16 +435,13 @@ class GaussianDiffusion: alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn( - x, self._scale_timesteps(t), **model_kwargs - ) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) out = p_mean_var.copy() out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.q_posterior_mean_variance( - x_start=out["pred_xstart"], x_t=x, t=t - ) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) return out + def k_diffusion_sample_loop( self, k_sampler, @@ -512,9 +463,7 @@ class GaussianDiffusion: def model_split(*args, **kwargs): model_output = model(*args, **kwargs) - model_epsilon, model_var = th.split( - model_output, model_output.shape[1] // 2, dim=1 - ) + model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) return model_epsilon, model_var # @@ -523,9 +472,7 @@ class GaussianDiffusion: print(th.tensor(self.betas)) noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) """ - noise_schedule = NoiseScheduleVP( - schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4 - ) + noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) def model_fn_prewrap(x, t, *args, **kwargs): """ @@ -584,11 +531,10 @@ class GaussianDiffusion: if self.conditioning_free is not True: raise RuntimeError("cond_free must be true") with tqdm(total=self.num_timesteps) as pbar: - return self.k_diffusion_sample_loop( - K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs - ) + return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) else: raise RuntimeError("sampler not impl") + def p_sample( self, model, @@ -625,13 +571,9 @@ class GaussianDiffusion: model_kwargs=model_kwargs, ) noise = th.randn_like(x) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 if cond_fn is not None: - out["mean"] = self.condition_mean( - cond_fn, out, x, t, model_kwargs=model_kwargs - ) + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -758,20 +700,11 @@ class GaussianDiffusion: alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = ( - eta - * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) - * th.sqrt(1 - alpha_bar / alpha_bar_prev) - ) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) # Equation 12. 
noise = th.randn_like(x) - mean_pred = ( - out["pred_xstart"] * th.sqrt(alpha_bar_prev) - + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps - ) - nonzero_mask = ( - (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 sample = mean_pred + nonzero_mask * sigma * noise return {"sample": sample, "pred_xstart": out["pred_xstart"]} @@ -800,16 +733,12 @@ class GaussianDiffusion: # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out["pred_xstart"] + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed - mean_pred = ( - out["pred_xstart"] * th.sqrt(alpha_bar_next) - + th.sqrt(1 - alpha_bar_next) * eps - ) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} @@ -897,9 +826,7 @@ class GaussianDiffusion: yield out img = out["sample"] - def _vb_terms_bpd( - self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None - ): + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): """ Get a term for the variational lower-bound. @@ -910,15 +837,9 @@ class GaussianDiffusion: - 'output': a shape [N] tensor of NLLs or KLs. - 'pred_xstart': the x_0 predictions. """ - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - ) - out = self.p_mean_variance( - model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs - ) - kl = normal_kl( - true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] - ) + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) kl = mean_flat(kl) / np.log(2.0) decoder_nll = -discretized_gaussian_log_likelihood( @@ -969,7 +890,7 @@ class GaussianDiffusion: model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) if isinstance(model_outputs, tuple): model_output = model_outputs[0] - terms['extra_outputs'] = model_outputs[1:] + terms["extra_outputs"] = model_outputs[1:] else: model_output = model_outputs @@ -996,9 +917,7 @@ class GaussianDiffusion: terms["vb"] *= self.num_timesteps / 1000.0 if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - )[0] + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] x_start_pred = torch.zeros(x_start) # Not supported. 
elif self.model_mean_type == ModelMeanType.START_X: target = x_start @@ -1020,7 +939,9 @@ class GaussianDiffusion: return terms - def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None): + def autoregressive_training_losses( + self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None + ): """ Compute training losses for a single timestep. @@ -1068,9 +989,7 @@ class GaussianDiffusion: terms["vb"] *= self.num_timesteps / 1000.0 if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance( - x_start=x_start, x_t=x_t, t=t - )[0] + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] x_start_pred = torch.zeros(x_start) # Not supported. elif self.model_mean_type == ModelMeanType.START_X: target = x_start @@ -1105,9 +1024,7 @@ class GaussianDiffusion: batch_size = x_start.shape[0] t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl( - mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 - ) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) return mean_flat(kl_prior) / np.log(2.0) def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): @@ -1183,9 +1100,7 @@ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): scale = 1000 / num_diffusion_timesteps beta_start = scale * 0.0001 beta_end = scale * 0.02 - return np.linspace( - beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 - ) + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) elif schedule_name == "cosine": return betas_for_alpha_bar( num_diffusion_timesteps, @@ -1219,19 +1134,13 @@ class SpacedDiffusion(GaussianDiffusion): kwargs["betas"] = np.array(new_betas) super().__init__(**kwargs) - def p_mean_variance( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) - def training_losses( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) - def autoregressive_training_losses( - self, model, *args, **kwargs - ): # pylint: disable=signature-differs + def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) def condition_mean(self, cond_fn, *args, **kwargs): @@ -1244,9 +1153,7 @@ class SpacedDiffusion(GaussianDiffusion): if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): return model mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel - return mod( - model, self.timestep_map, self.rescale_timesteps, self.original_num_steps - ) + return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) def _scale_timesteps(self, t): # Scaling is done by the wrapped model. 
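# Aside (not part of the diff, semantics assumed from the upstream guided-diffusion
# code this module derives from): space_timesteps, reformatted in the next hunk, selects
# which of the original timesteps survive respacing, and SpacedDiffusion's timestep_map
# remaps the kept steps back to their original indices for the wrapped model.
from TTS.tts.layers.tortoise.diffusion import space_timesteps

print(sorted(space_timesteps(1000, "ddim50"))[:5])  # 50 steps on a fixed integer stride
print(len(space_timesteps(1000, [30])))             # 30 roughly evenly spaced steps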
@@ -1281,9 +1188,7 @@ def space_timesteps(num_timesteps, section_counts): for i in range(1, num_timesteps): if len(range(0, num_timesteps, i)) == desired_count: return set(range(0, num_timesteps, i)) - raise ValueError( - f"cannot create exactly {num_timesteps} steps with an integer stride" - ) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") section_counts = [int(x) for x in section_counts.split(",")] size_per = num_timesteps // len(section_counts) extra = num_timesteps % len(section_counts) @@ -1292,9 +1197,7 @@ def space_timesteps(num_timesteps, section_counts): for i, section_count in enumerate(section_counts): size = size_per + (1 if i < extra else 0) if size < section_count: - raise ValueError( - f"cannot divide section of {size} steps into {section_count}" - ) + raise ValueError(f"cannot divide section of {size} steps into {section_count}") if section_count <= 1: frac_stride = 1 else: @@ -1315,6 +1218,7 @@ class _WrappedModel: self.timestep_map = timestep_map self.rescale_timesteps = rescale_timesteps self.original_num_steps = original_num_steps + def __call__(self, x, ts, **kwargs): map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) new_ts = map_tensor[ts] @@ -1323,6 +1227,7 @@ class _WrappedModel: model_output = self.model(x, new_ts, **kwargs) return model_output + class _WrappedAutoregressiveModel: def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): self.model = model @@ -1337,6 +1242,7 @@ class _WrappedAutoregressiveModel: new_ts = new_ts.float() * (1000.0 / self.original_num_steps) return self.model(x, x0, new_ts, **kwargs) + def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. @@ -1350,4 +1256,4 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] - return res.expand(broadcast_shape) \ No newline at end of file + return res.expand(broadcast_shape) diff --git a/TTS/tts/layers/tortoise/diffusion_decoder.py b/TTS/tts/layers/tortoise/diffusion_decoder.py index 6b1300d9..0d3cf769 100644 --- a/TTS/tts/layers/tortoise/diffusion_decoder.py +++ b/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -29,11 +29,9 @@ def timestep_embedding(timesteps, dim, max_period=10000): :return: an [N x dim] Tensor of positional embeddings. 
""" half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: @@ -98,17 +96,13 @@ class ResBlock(TimestepBlock): normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - nn.Conv1d( - self.out_channels, self.out_channels, kernel_size, padding=padding - ), + nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding), ) if self.out_channels == channels: self.skip_connection = nn.Identity() else: - self.skip_connection = nn.Conv1d( - channels, self.out_channels, eff_kernel, padding=eff_padding - ) + self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding) def forward(self, x, emb): h = self.in_layers(x) @@ -137,9 +131,7 @@ class DiffusionLayer(TimestepBlock): dims=1, use_scale_shift_norm=True, ) - self.attn = AttentionBlock( - model_channels, num_heads, relative_pos_embeddings=True - ) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) def forward(self, x, time_emb): y = self.resblk(x, time_emb) @@ -239,16 +231,11 @@ class DiffusionTts(nn.Module): DiffusionLayer(model_channels, dropout, num_heads), ) - self.integrating_conv = nn.Conv1d( - model_channels * 2, model_channels, kernel_size=1 - ) + self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1) self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) self.layers = nn.ModuleList( - [ - DiffusionLayer(model_channels, dropout, num_heads) - for _ in range(num_layers) - ] + [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + [ ResBlock( model_channels, @@ -275,9 +262,7 @@ class DiffusionTts(nn.Module): + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()) + list(self.latent_conditioner.parameters()), - "timestep_integrator": list( - self.conditioning_timestep_integrator.parameters() - ) + "timestep_integrator": list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()), "time_embed": list(self.time_embed.parameters()), } @@ -285,9 +270,7 @@ class DiffusionTts(nn.Module): def get_conditioning(self, conditioning_input): speech_conditioning_input = ( - conditioning_input.unsqueeze(1) - if len(conditioning_input.shape) == 3 - else conditioning_input + conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input ) conds = [] for j in range(speech_conditioning_input.shape[1]): @@ -313,29 +296,20 @@ class DiffusionTts(nn.Module): else: code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) code_emb = self.code_converter(code_emb) - code_emb = self.code_norm(code_emb) * ( - 1 + cond_scale.unsqueeze(-1) - ) + cond_shift.unsqueeze(-1) + code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) - unconditioned_batches = torch.zeros( - (code_emb.shape[0], 1, 1), device=code_emb.device - ) + unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. 
if self.training and self.unconditioned_percentage > 0: unconditioned_batches = ( - torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) - < self.unconditioned_percentage + torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage ) code_emb = torch.where( unconditioned_batches, - self.unconditioned_embedding.repeat( - aligned_conditioning.shape[0], 1, 1 - ), + self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), code_emb, ) - expanded_code_emb = F.interpolate( - code_emb, size=expected_seq_len, mode="nearest" - ) + expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest") if not return_code_pred: return expanded_code_emb @@ -376,10 +350,7 @@ class DiffusionTts(nn.Module): unused_params = [] if conditioning_free: code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) - unused_params.extend( - list(self.code_converter.parameters()) - + list(self.code_embedding.parameters()) - ) + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) unused_params.extend(list(self.latent_conditioner.parameters())) else: if precomputed_aligned_embeddings is not None: @@ -390,8 +361,7 @@ class DiffusionTts(nn.Module): ) if is_latent(aligned_conditioning): unused_params.extend( - list(self.code_converter.parameters()) - + list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.code_embedding.parameters()) ) else: unused_params.extend(list(self.latent_conditioner.parameters())) diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py index d4d383d6..cb540577 100644 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -1,6 +1,8 @@ import math + import torch + class NoiseScheduleVP: def __init__( self, @@ -107,11 +109,7 @@ class NoiseScheduleVP: log_alphas = 0.5 * torch.log(alphas_cumprod) self.total_N = len(log_alphas) self.T = 1.0 - self.t_array = ( - torch.linspace(0.0, 1.0, self.total_N + 1)[1:] - .reshape((1, -1)) - .to(dtype=dtype) - ) + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) self.log_alpha_array = log_alphas.reshape( ( 1, @@ -131,9 +129,7 @@ class NoiseScheduleVP: / math.pi - self.cosine_s ) - self.cosine_log_alpha_0 = math.log( - math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0) - ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) self.schedule = schedule if schedule == "cosine": # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. @@ -157,11 +153,7 @@ class NoiseScheduleVP: elif self.schedule == "cosine": def log_alpha_fn(s): - return torch.log( - torch.cos( - (s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0 - ) - ) + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 return log_alpha_t @@ -191,17 +183,11 @@ class NoiseScheduleVP: Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 
""" if self.schedule == "linear": - tmp = ( - 2.0 - * (self.beta_1 - self.beta_0) - * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) - ) + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) elif self.schedule == "discrete": - log_alpha = -0.5 * torch.logaddexp( - torch.zeros((1,)).to(lamb.device), -2.0 * lamb - ) + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) t = interpolate_fn( log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), @@ -345,14 +331,10 @@ def model_wrapper( if model_type == "noise": return output elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha( - t_continuous - ), noise_schedule.marginal_std(t_continuous) + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) return (x - alpha_t * output) / sigma_t elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha( - t_continuous - ), noise_schedule.marginal_std(t_continuous) + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) return alpha_t * output + sigma_t * x elif model_type == "score": sigma_t = noise_schedule.marginal_std(t_continuous) @@ -482,9 +464,7 @@ class DPM_Solver: p = self.dynamic_thresholding_ratio s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims( - torch.maximum( - s, self.thresholding_max_val * torch.ones_like(s).to(s.device) - ), + torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims, ) x0 = torch.clamp(x0, -s, s) / s @@ -501,9 +481,7 @@ class DPM_Solver: Return the data prediction model (with corrector). 
""" noise = self.noise_prediction_fn(x, t) - alpha_t, sigma_t = self.noise_schedule.marginal_alpha( - t - ), self.noise_schedule.marginal_std(t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) x0 = (x - sigma_t * noise) / alpha_t if self.correcting_x0_fn is not None: x0 = self.correcting_x0_fn(x0, t) @@ -536,30 +514,20 @@ class DPM_Solver: if skip_type == "logSNR": lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace( - lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1 - ).to(device) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) return self.noise_schedule.inverse_lambda(logSNR_steps) elif skip_type == "time_uniform": return torch.linspace(t_T, t_0, N + 1).to(device) elif skip_type == "time_quadratic": t_order = 2 - t = ( - torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1) - .pow(t_order) - .to(device) - ) + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format( - skip_type - ) + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) ) - def get_orders_and_timesteps_for_singlestep_solver( - self, steps, order, skip_type, t_T, t_0, device - ): + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ Get the order of each step for sampling by the singlestep DPM-Solver. @@ -664,9 +632,7 @@ class DPM_Solver: dims = x.dim() lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff( - s - ), ns.marginal_log_mean_coeff(t) + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -716,11 +682,7 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. """ if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError( - "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format( - solver_type - ) - ) + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 0.5 ns = self.noise_schedule @@ -766,10 +728,7 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - torch.exp(log_alpha_s1 - log_alpha_s) * x - - (sigma_s1 * phi_11) * model_s - ) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s model_s1 = self.model_fn(x_s1, s1) if solver_type == "dpmsolver": x_t = ( @@ -820,11 +779,7 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. 
""" if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError( - "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format( - solver_type - ) - ) + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) if r1 is None: r1 = 1.0 / 3.0 if r2 is None: @@ -901,9 +856,7 @@ class DPM_Solver: if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - ( - sigma_s1 * phi_11 - ) * model_s + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s model_s1 = self.model_fn(x_s1, s1) x_s2 = ( (torch.exp(log_alpha_s2 - log_alpha_s)) * x @@ -934,9 +887,7 @@ class DPM_Solver: else: return x_t - def multistep_dpm_solver_second_update( - self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver" - ): + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): """ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. @@ -951,11 +902,7 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. """ if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError( - "'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format( - solver_type - ) - ) + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) ns = self.noise_schedule model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] @@ -964,9 +911,7 @@ class DPM_Solver: ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t), ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( - t_prev_0 - ), ns.marginal_log_mean_coeff(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -977,11 +922,7 @@ class DPM_Solver: if self.algorithm_type == "dpmsolver++": phi_1 = torch.expm1(-h) if solver_type == "dpmsolver": - x_t = ( - (sigma_t / sigma_prev_0) * x - - (alpha_t * phi_1) * model_prev_0 - - 0.5 * (alpha_t * phi_1) * D1_0 - ) + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 elif solver_type == "taylor": x_t = ( (sigma_t / sigma_prev_0) * x @@ -1004,9 +945,7 @@ class DPM_Solver: ) return x_t - def multistep_dpm_solver_third_update( - self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver" - ): + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): """ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. @@ -1029,9 +968,7 @@ class DPM_Solver: ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t), ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff( - t_prev_0 - ), ns.marginal_log_mean_coeff(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) alpha_t = torch.exp(log_alpha_t) @@ -1093,9 +1030,7 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. 
""" if order == 1: - return self.dpm_solver_first_update( - x, s, t, return_intermediate=return_intermediate - ) + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: return self.singlestep_dpm_solver_second_update( x, @@ -1118,9 +1053,7 @@ class DPM_Solver: else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - def multistep_dpm_solver_update( - self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver" - ): + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): """ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. @@ -1136,17 +1069,11 @@ class DPM_Solver: x_t: A pytorch tensor. The approximated solution at time `t`. """ if order == 1: - return self.dpm_solver_first_update( - x, t_prev_list[-1], t, model_s=model_prev_list[-1] - ) + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) elif order == 2: - return self.multistep_dpm_solver_second_update( - x, model_prev_list, t_prev_list, t, solver_type=solver_type - ) + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) elif order == 3: - return self.multistep_dpm_solver_third_update( - x, model_prev_list, t_prev_list, t, solver_type=solver_type - ) + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) @@ -1198,9 +1125,7 @@ class DPM_Solver: return self.dpm_solver_first_update(x, s, t, return_intermediate=True) def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_second_update( - x, s, t, r1=r1, solver_type=solver_type, **kwargs - ) + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) elif order == 3: r1, r2 = 1.0 / 3.0, 2.0 / 3.0 @@ -1211,16 +1136,10 @@ class DPM_Solver: ) def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_third_update( - x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs - ) + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) else: - raise ValueError( - "For adaptive step size solver, order must be 2 or 3, got {}".format( - order - ) - ) + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: t = ns.inverse_lambda(lambda_s + h) x_lower, lower_noise_kwargs = lower_update(x, s, t) @@ -1231,9 +1150,7 @@ class DPM_Solver: ) def norm_fn(v): - return torch.sqrt( - torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True) - ) + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) E = norm_fn((x_higher - x_lower) / delta).max() if torch.all(E <= 1.0): @@ -1259,9 +1176,7 @@ class DPM_Solver: Returns: xt with shape `(t_size, batch_size, *shape)`. 
""" - alpha_t, sigma_t = self.noise_schedule.marginal_alpha( - t - ), self.noise_schedule.marginal_std(t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) if noise is None: noise = torch.randn((t.shape[0], *x.shape), device=x.device) x = x.reshape((-1, *x.shape)) @@ -1468,9 +1383,7 @@ class DPM_Solver: ) elif method == "multistep": assert steps >= order - timesteps = self.get_time_steps( - skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device - ) + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) assert timesteps.shape[0] - 1 == steps # Init the initial values. step = 0 @@ -1527,10 +1440,7 @@ class DPM_Solver: model_prev_list[-1] = self.model_fn(x, t) elif method in ["singlestep", "singlestep_fixed"]: if method == "singlestep": - ( - timesteps_outer, - orders, - ) = self.get_orders_and_timesteps_for_singlestep_solver( + (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver( steps=steps, order=order, skip_type=skip_type, @@ -1543,9 +1453,7 @@ class DPM_Solver: orders = [ order, ] * K - timesteps_outer = self.get_time_steps( - skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device - ) + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for step, order in enumerate(orders): s, t = timesteps_outer[step], timesteps_outer[step + 1] timesteps_inner = self.get_time_steps( @@ -1559,9 +1467,7 @@ class DPM_Solver: h = lambda_inner[-1] - lambda_inner[0] r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update( - x, s, t, order, solver_type=solver_type, r1=r1, r2=r2 - ) + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) if self.correcting_xt_fn is not None: x = self.correcting_xt_fn(x, t, step) if return_intermediate: @@ -1613,9 +1519,7 @@ def interpolate_fn(x, xp, yp): cand_start_idx, ), ) - end_idx = torch.where( - torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1 - ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) start_idx2 = torch.where( @@ -1628,12 +1532,8 @@ def interpolate_fn(x, xp, yp): ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather( - y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2) - ).squeeze(2) - end_y = torch.gather( - y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2) - ).squeeze(2) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) return cand diff --git a/TTS/tts/layers/tortoise/random_latent_generator.py b/TTS/tts/layers/tortoise/random_latent_generator.py index 88bb0880..9b39c1e4 100644 --- a/TTS/tts/layers/tortoise/random_latent_generator.py +++ b/TTS/tts/layers/tortoise/random_latent_generator.py @@ -40,8 +40,7 @@ class RandomLatentConverter(nn.Module): def __init__(self, channels): super().__init__() self.layers = nn.Sequential( - *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], - nn.Linear(channels, channels) + 
*[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels) ) self.channels = channels diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py index 88738941..1f399148 100644 --- a/TTS/tts/layers/tortoise/tokenizer.py +++ b/TTS/tts/layers/tortoise/tokenizer.py @@ -95,9 +95,7 @@ def _expand_number(m): elif num % 100 == 0: return _inflect.number_to_words(num // 100) + " hundred" else: - return _inflect.number_to_words( - num, andword="", zero="oh", group=2 - ).replace(", ", " ") + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") else: return _inflect.number_to_words(num, andword="") @@ -165,9 +163,7 @@ def lev_distance(s1, s2): if c1 == c2: distances_.append(distances[i1]) else: - distances_.append( - 1 + min((distances[i1], distances[i1 + 1], distances_[-1])) - ) + distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1]))) distances = distances_ return distances[-1] diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py index 6f5bf5a3..70d46aa3 100644 --- a/TTS/tts/layers/tortoise/transformer.py +++ b/TTS/tts/layers/tortoise/transformer.py @@ -36,12 +36,8 @@ def route_args(router, args, depth): for key in matched_keys: val = args[key] - for depth, ((f_args, g_args), routes) in enumerate( - zip(routed_args, router[key]) - ): - new_f_args, new_g_args = map( - lambda route: ({key: val} if route else {}), routes - ) + for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) return routed_args @@ -217,12 +213,8 @@ class Transformer(nn.Module): layers.append( nn.ModuleList( [ - LayerScale( - dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm) - ), - LayerScale( - dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm) - ), + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)), ] ) ) diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py index e8140a03..151ea803 100644 --- a/TTS/tts/layers/tortoise/utils.py +++ b/TTS/tts/layers/tortoise/utils.py @@ -1,5 +1,7 @@ import os -try: import gdown + +try: + import gdown except ImportError: raise ImportError( "Sorry, gdown is required in order to download the new BigVGAN vocoder.\n" @@ -11,9 +13,7 @@ import progressbar D_STEM = "https://drive.google.com/uc?id=" -DEFAULT_MODELS_DIR = os.path.join( - os.path.expanduser("~"), ".cache", "tortoise", "models" -) +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) MODELS = { "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", @@ -30,6 +30,8 @@ MODELS = { } pbar = None + + def download_models(specific_models=None): """ Call to download all the models that Tortoise uses. @@ -62,6 +64,7 @@ def download_models(specific_models=None): request.urlretrieve(url, model_path, show_progress) print("Done.") + def get_model_path(model_name, models_dir=MODELS_DIR): """ Get path to given model, download it if it doesn't exist. 
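# A minimal usage sketch of the cache helpers reformatted above (not part of the
# diff; it assumes only names visible in these hunks: MODELS_DIR, get_model_path,
# download_models, and the "autoregressive.pth" entry of MODELS).
import os

from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path

# MODELS_DIR honours the TORTOISE_MODELS_DIR environment variable and falls back
# to ~/.cache/tortoise/models; get_model_path() fetches a missing checkpoint via
# download_models() before returning its local path.
ar_path = get_model_path("autoregressive.pth", models_dir=MODELS_DIR)
print(os.path.exists(ar_path))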
@@ -71,4 +74,4 @@ def get_model_path(model_name, models_dir=MODELS_DIR): model_path = os.path.join(models_dir, model_name) if not os.path.exists(model_path) and models_dir == MODELS_DIR: download_models([model_name]) - return model_path \ No newline at end of file + return model_path diff --git a/TTS/tts/layers/tortoise/vocoder.py b/TTS/tts/layers/tortoise/vocoder.py index c8a145d1..47365eb5 100644 --- a/TTS/tts/layers/tortoise/vocoder.py +++ b/TTS/tts/layers/tortoise/vocoder.py @@ -1,12 +1,12 @@ +import json +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional + import torch import torch.nn as nn import torch.nn.functional as F -import json -from enum import Enum -from typing import Optional, Callable -from dataclasses import dataclass - MAX_WAV_VALUE = 32768.0 @@ -40,18 +40,12 @@ class KernelPredictor(torch.nn.Module): self.conv_kernel_size = conv_kernel_size self.conv_layers = conv_layers - kpnet_kernel_channels = ( - conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers - ) # l_w + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w kpnet_bias_channels = conv_out_channels * conv_layers # l_b self.input_conv = nn.Sequential( - nn.utils.weight_norm( - nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) - ), - getattr(nn, kpnet_nonlinear_activation)( - **kpnet_nonlinear_activation_params - ), + nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), ) self.residual_convs = nn.ModuleList() @@ -69,9 +63,7 @@ class KernelPredictor(torch.nn.Module): bias=True, ) ), - getattr(nn, kpnet_nonlinear_activation)( - **kpnet_nonlinear_activation_params - ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), nn.utils.weight_norm( nn.Conv1d( kpnet_hidden_channels, @@ -81,9 +73,7 @@ class KernelPredictor(torch.nn.Module): bias=True, ) ), - getattr(nn, kpnet_nonlinear_activation)( - **kpnet_nonlinear_activation_params - ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), ) ) self.kernel_conv = nn.utils.weight_norm( @@ -252,17 +242,11 @@ class LVCBlock(torch.nn.Module): """ batch, _, in_length = x.shape batch, _, out_channels, kernel_size, kernel_length = kernel.shape - assert in_length == ( - kernel_length * hop_size - ), "length of (x, kernel) is not matched" + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" padding = dilation * int((kernel_size - 1) / 2) - x = F.pad( - x, (padding, padding), "constant", 0 - ) # (batch, in_channels, in_length + 2*padding) - x = x.unfold( - 2, hop_size + 2 * padding, hop_size - ) # (batch, in_channels, kernel_length, hop_size + 2*padding) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) if hop_size < dilation: x = F.pad(x, (0, dilation), "constant", 0) @@ -270,12 +254,8 @@ class LVCBlock(torch.nn.Module): 3, dilation, dilation ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) x = x[:, :, :, :, :hop_size] - x = x.transpose( - 3, 4 - ) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) - x = x.unfold( - 4, kernel_size, 1 - ) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + x = x.transpose(3, 4) # 
(batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) o = torch.einsum("bildsk,biokl->bolsd", x, kernel) o = o.to(memory_format=torch.channels_last_3d) @@ -334,15 +314,11 @@ class UnivNetGenerator(nn.Module): ) ) - self.conv_pre = nn.utils.weight_norm( - nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect") - ) + self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")) self.conv_post = nn.Sequential( nn.LeakyReLU(lReLU_slope), - nn.utils.weight_norm( - nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect") - ), + nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), nn.Tanh(), ) @@ -399,12 +375,16 @@ class VocType: constructor: Callable[[], nn.Module] model_path: str subkey: Optional[str] = None + def optionally_index(self, model_dict): if self.subkey is not None: return model_dict[self.subkey] return model_dict + + class VocConf(Enum): - Univnet = VocType(UnivNetGenerator, "vocoder.pth", 'model_g') + Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g") + if __name__ == "__main__": model = UnivNetGenerator() diff --git a/TTS/tts/layers/tortoise/wav2vec_alignment.py b/TTS/tts/layers/tortoise/wav2vec_alignment.py index c76d4daf..47456cc5 100644 --- a/TTS/tts/layers/tortoise/wav2vec_alignment.py +++ b/TTS/tts/layers/tortoise/wav2vec_alignment.py @@ -12,9 +12,7 @@ def max_alignment(s1, s2, skip_character="~", record=None): """ if record is None: record = {} - assert ( - skip_character not in s1 - ), f"Found the skip character {skip_character} in the provided string, {s1}" + assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}" if len(s1) == 0: return "" if len(s2) == 0: @@ -49,15 +47,9 @@ class Wav2VecAlignment: """ def __init__(self, device="cuda"): - self.model = Wav2Vec2ForCTC.from_pretrained( - "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli" - ).cpu() - self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "facebook/wav2vec2-large-960h" - ) - self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained( - "jbetker/tacotron-symbols" - ) + self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h") + self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols") self.device = device def align(self, audio, expected_text, audio_sample_rate=24000): @@ -117,9 +109,7 @@ class Wav2VecAlignment: ) # Now fix up alignments. Anything with -1 should be interpolated. - alignments.append( - orig_len - ) # This'll get removed but makes the algorithm below more readable. + alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable. 
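# Illustration of the gap-filling step below (toy values assumed, not taken from
# the diff): with alignments == [10, -1, -1, 40], the first -1 sits at i == 1 and
# the next resolved token at next_found_token == 3, so gap == 40 - 10 == 30 and
# alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1]
# yields [10, 20, 30, 40]; unresolved positions are spread linearly between their
# resolved neighbours.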
for i in range(len(alignments)): if alignments[i] == -1: for j in range(i + 1, len(alignments)): @@ -128,9 +118,7 @@ class Wav2VecAlignment: break for j in range(i, next_found_token): gap = alignments[next_found_token] - alignments[i - 1] - alignments[j] = (j - i + 1) * gap // ( - next_found_token - i + 1 - ) + alignments[i - 1] + alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1] return alignments[:-1] @@ -140,9 +128,7 @@ class Wav2VecAlignment: splitted = expected_text.split("[") fully_split = [splitted[0]] for spl in splitted[1:]: - assert ( - "]" in spl - ), 'Every "[" character must be paired with a "]" with no nesting.' + assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.' fully_split.extend(spl.split("]")) # At this point, fully_split is a list of strings, with every other string being something that should be redacted. diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py index 8be2df45..0c6a70d6 100644 --- a/TTS/tts/layers/tortoise/xtransformers.py +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -6,24 +6,25 @@ from inspect import isfunction import torch import torch.nn.functional as F from einops import rearrange, repeat -from torch import nn, einsum +from torch import einsum, nn DEFAULT_DIM_HEAD = 64 -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates', - 'past_key_values', -]) +LayerIntermediates = namedtuple( + "Intermediates", + [ + "hiddens", + "attn_intermediates", + "past_key_values", + ], +) # helpers + def exists(val): return val is not None @@ -38,7 +39,7 @@ def cast_tuple(val, depth): return val if isinstance(val, tuple) else (val,) * depth -class always(): +class always: def __init__(self, val): self.val = val @@ -46,7 +47,7 @@ class always(): return self.val -class not_equals(): +class not_equals: def __init__(self, val): self.val = val @@ -54,7 +55,7 @@ class not_equals(): return x != self.val -class equals(): +class equals: def __init__(self, val): self.val = val @@ -72,14 +73,16 @@ def l2norm(t): # init helpers + def init_zero_(layer): - nn.init.constant_(layer.weight, 0.) + nn.init.constant_(layer.weight, 0.0) if exists(layer.bias): - nn.init.constant_(layer.bias, 0.) 
+ nn.init.constant_(layer.bias, 0.0) # keyword argument helpers + def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) @@ -104,12 +107,13 @@ def group_by_key_prefix(prefix, d): def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) return kwargs_without_prefix, kwargs # activations + class ReluSquared(nn.Module): def forward(self, x): return F.relu(x) ** 2 @@ -117,30 +121,31 @@ class ReluSquared(nn.Module): # positional embeddings + class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.emb = nn.Embedding(max_seq_len, dim) def forward(self, x): n = torch.arange(x.shape[1], device=x.device) pos_emb = self.emb(n) - pos_emb = rearrange(pos_emb, 'n d -> () n d') + pos_emb = rearrange(pos_emb, "n d -> () n d") return pos_emb * self.scale class FixedPositionalEmbedding(nn.Module): def __init__(self, dim): super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) def forward(self, x, seq_dim=1, offset=0): t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset - sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) - return rearrange(emb, 'n d -> () n d') + return rearrange(emb, "n d -> () n d") class RelativePositionBias(nn.Module): @@ -166,9 +171,10 @@ class RelativePositionBias(nn.Module): max_exact = num_buckets // 2 is_small = n < max_exact - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).long() + val_if_large = ( + max_exact + + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() + ) val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) ret += torch.where(is_small, n, val_if_large) @@ -179,10 +185,11 @@ class RelativePositionBias(nn.Module): q_pos = torch.arange(i, dtype=torch.long, device=device) k_pos = torch.arange(j, dtype=torch.long, device=device) rel_pos = k_pos[None, :] - q_pos[:, None] - rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, - max_distance=self.max_distance) + rp_bucket = self._relative_position_bucket( + rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance + ) values = self.relative_attention_bias(rp_bucket) - bias = rearrange(values, 'i j h -> () h i j') + bias = rearrange(values, "i j h -> () h i j") return qk_dots + (bias * self.scale) @@ -191,23 +198,25 @@ class AlibiPositionalBias(nn.Module): super().__init__() self.heads = heads slopes = torch.Tensor(self._get_slopes(heads)) - slopes = rearrange(slopes, 'h -> () h () ()') - self.register_buffer('slopes', slopes, persistent=False) - self.register_buffer('bias', None, persistent=False) + slopes = rearrange(slopes, "h -> () h () ()") + self.register_buffer("slopes", 
slopes, persistent=False) + self.register_buffer("bias", None, persistent=False) @staticmethod def _get_slopes(heads): def get_slopes_power_of_2(n): - start = (2 ** (-2 ** -(math.log2(n) - 3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start - return [start * ratio ** i for i in range(n)] + return [start * ratio**i for i in range(n)] if math.log2(heads).is_integer(): return get_slopes_power_of_2(heads) closest_power_of_2 = 2 ** math.floor(math.log2(heads)) - return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ - :heads - closest_power_of_2] + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2] + ) def forward(self, qk_dots): h, i, j, device = *qk_dots.shape[-3:], qk_dots.device @@ -216,13 +225,13 @@ class AlibiPositionalBias(nn.Module): return qk_dots + self.bias[..., :j] bias = torch.arange(j, device=device) - bias = rearrange(bias, 'j -> () () () j') + bias = rearrange(bias, "j -> () () () j") bias = bias * self.slopes num_heads_unalibied = h - bias.shape[1] bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) - self.register_buffer('bias', bias, persistent=False) + self.register_buffer("bias", bias, persistent=False) return qk_dots + self.bias @@ -247,8 +256,8 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias): else: i_arange = torch.arange(i, device=device) j_arange = torch.arange(j, device=device) - bias = rearrange(j_arange, 'j -> 1 1 1 j') - rearrange(i_arange, 'i -> 1 1 i 1') - self.register_buffer('bias', bias, persistent=False) + bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1") + self.register_buffer("bias", bias, persistent=False) if self.bidirectional: past_slopes = get_slopes(self.learned_logslopes) @@ -264,18 +273,18 @@ class LearnedAlibiPositionalBias(AlibiPositionalBias): class RotaryEmbedding(nn.Module): def __init__(self, dim): super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) def forward(self, max_seq_len, device): t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) - freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - return rearrange(emb, 'n d -> () () n d') + return rearrange(emb, "n d -> () () n d") def rotate_half(x): - x = rearrange(x, '... (j d) -> ... j d', j=2) + x = rearrange(x, "... (j d) -> ... 
j d", j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) @@ -288,6 +297,7 @@ def apply_rotary_pos_emb(t, freqs): # norms + class Scale(nn.Module): def __init__(self, value, fn): super().__init__() @@ -323,7 +333,7 @@ class Rezero(nn.Module): class ScaleNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) @@ -335,7 +345,7 @@ class ScaleNorm(nn.Module): class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) @@ -347,7 +357,7 @@ class RMSNorm(nn.Module): class RMSScaleShiftNorm(nn.Module): def __init__(self, dim, eps=1e-8): super().__init__() - self.scale = dim ** -0.5 + self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) self.scale_shift_process = nn.Linear(dim * 2, dim * 2) @@ -364,6 +374,7 @@ class RMSScaleShiftNorm(nn.Module): # residual and residual gates + class Residual(nn.Module): def __init__(self, dim, scale_residual=False): super().__init__() @@ -386,24 +397,22 @@ class GRUGating(nn.Module): if exists(self.residual_scale): residual = residual * self.residual_scale - gated_output = self.gru( - rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') - ) + gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")) return gated_output.reshape_as(x) # token shifting + def shift(t, amount, mask=None): if amount == 0: return t if exists(mask): - t = t.masked_fill(~mask[..., None], 0.) + t = t.masked_fill(~mask[..., None], 0.0) - return F.pad(t, (0, 0, amount, -amount), value=0.) + return F.pad(t, (0, 0, amount, -amount), value=0.0) class ShiftTokens(nn.Module): @@ -413,7 +422,7 @@ class ShiftTokens(nn.Module): self.shifts = tuple(shifts) def forward(self, x, **kwargs): - mask = kwargs.get('mask', None) + mask = kwargs.get("mask", None) shifts = self.shifts segments = len(shifts) feats_per_shift = x.shape[-1] // segments @@ -426,6 +435,7 @@ class ShiftTokens(nn.Module): # feedforward + class GLU(nn.Module): def __init__(self, dim_in, dim_out, activation): super().__init__() @@ -439,31 +449,30 @@ class GLU(nn.Module): class FeedForward(nn.Module): def __init__( - self, - dim, - dim_out=None, - mult=4, - glu=False, - relu_squared=False, - post_act_ln=False, - dropout=0., - zero_init_output=False + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0.0, + zero_init_output=False, ): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) activation = ReluSquared() if relu_squared else nn.GELU() - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - activation - ) if not glu else GLU(dim, inner_dim, activation) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), activation) if not glu else GLU(dim, inner_dim, activation) + ) self.net = nn.Sequential( project_in, nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + nn.Linear(inner_dim, dim_out), ) # init last linear layer to 0 @@ -476,33 +485,34 @@ class FeedForward(nn.Module): # attention. 
+ class Attention(nn.Module): def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - talking_heads=False, - head_scale=False, - collab_heads=False, - collab_compression=.3, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False, - gate_values=False, - zero_init_output=False, - max_attend_past=None, - qk_norm=False, - scale_init_value=None, - rel_pos_bias=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=0.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, ): super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.causal = causal @@ -532,8 +542,9 @@ class Attention(nn.Module): # cosine sim attention self.qk_norm = qk_norm if qk_norm: - scale_init_value = default(scale_init_value, - -3) # if not provided, initialize as though it were sequence length of 1024 + scale_init_value = default( + scale_init_value, -3 + ) # if not provided, initialize as though it were sequence length of 1024 self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) # talking heads @@ -565,29 +576,44 @@ class Attention(nn.Module): self.rel_pos_bias = rel_pos_bias if rel_pos_bias: - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' - self.rel_pos = RelativePositionBias(scale=dim_head ** 0.5, causal=causal, heads=heads, - num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance) + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + self.rel_pos = RelativePositionBias( + scale=dim_head**0.5, + causal=causal, + heads=heads, + num_buckets=rel_pos_num_buckets, + max_distance=rel_pos_max_distance, + ) # init output projection 0 if zero_init_output: init_zero_(self.to_out) def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - attn_mask=None, - sinusoidal_emb=None, - rotary_pos_emb=None, - prev_attn=None, - mem=None, - layer_past=None, + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + layer_past=None, ): - b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, self.scale, x.device, exists( - context) + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = ( + *x.shape, + self.heads, + self.talking_heads, + self.collab_heads, + self.head_scale, + self.scale, + x.device, + exists(context), + ) kv_input = default(context, x) q_input = x @@ -609,11 +635,11 @@ class Attention(nn.Module): v = self.to_v(v_input) if not collab_heads: - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) else: - q = einsum('b i d, h d -> b h i d', q, self.collab_mixing) - k = rearrange(k, 'b n d -> b () n d') - v = rearrange(v, 'b n (h d) -> b h n d', h=h) + q = 
einsum("b i d, h d -> b h i d", q, self.collab_mixing) + k = rearrange(k, "b n d -> b () n d") + v = rearrange(v, "b n (h d) -> b h n d", h=h) if layer_past is not None: past_key, past_value = layer_past @@ -633,12 +659,12 @@ class Attention(nn.Module): q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) k_mask = q_mask if not exists(context) else context_mask k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) - q_mask = rearrange(q_mask, 'b i -> b () i ()') - k_mask = rearrange(k_mask, 'b j -> b () () j') + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") input_mask = q_mask * k_mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map(lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v)) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): @@ -651,7 +677,7 @@ class Attention(nn.Module): q, k = map(l2norm, (q, k)) scale = 1 / (self.scale.exp().clamp(min=1e-2)) - dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale + dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale mask_value = max_neg_value(dots) if exists(prev_attn): @@ -660,7 +686,7 @@ class Attention(nn.Module): pre_softmax_attn = dots.clone() if talking_heads: - dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous() if self.rel_pos_bias: dots = self.rel_pos(dots) @@ -670,18 +696,20 @@ class Attention(nn.Module): del input_mask if exists(attn_mask): - assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' + assert ( + 2 <= attn_mask.ndim <= 4 + ), "attention mask must have greater than 2 dimensions but less than or equal to 4" if attn_mask.ndim == 2: - attn_mask = rearrange(attn_mask, 'i j -> () () i j') + attn_mask = rearrange(attn_mask, "i j -> () () i j") elif attn_mask.ndim == 3: - attn_mask = rearrange(attn_mask, 'h i j -> () h i j') + attn_mask = rearrange(attn_mask, "h i j -> () h i j") dots.masked_fill_(~attn_mask, mask_value) if exists(self.max_attend_past): i, j = dots.shape[-2:] range_q = torch.arange(j - i, j, device=device) range_k = torch.arange(j, device=device) - dist = rearrange(range_q, 'i -> () () i ()') - rearrange(range_k, 'j -> () () () j') + dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j") mask = dist > self.max_attend_past dots.masked_fill_(mask, mask_value) del mask @@ -689,7 +717,7 @@ class Attention(nn.Module): if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask @@ -707,74 +735,71 @@ class Attention(nn.Module): attn = self.dropout(attn) if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous() - out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = einsum("b h i j, b h j d -> b h i d", attn, v) if head_scale: out = out * self.head_scale_params - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") if 
exists(self.to_v_gate): gates = self.to_v_gate(x) out = out * gates.sigmoid() - intermediates = Intermediates( - pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn - ) + intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn) return self.to_out(out), intermediates, k_cache, v_cache class AttentionLayers(nn.Module): def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rms_scaleshift_norm=False, - use_rmsnorm=False, - use_rezero=False, - alibi_pos_bias=False, - alibi_num_heads=None, - alibi_learned=False, - position_infused_attn=False, - rotary_pos_emb=False, - rotary_emb_dim=None, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - scale_residual=False, - shift_tokens=0, - sandwich_norm=False, - use_qk_norm_attn=False, - qk_norm_attn_seq_len=None, - zero_init_branch_output=False, - **kwargs + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs, ): super().__init__() - ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) - attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth self.layers = nn.ModuleList([]) self.causal = causal - rel_pos_bias = 'rel_pos_bias' in attn_kwargs + rel_pos_bias = "rel_pos_bias" in attn_kwargs self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None @@ -782,17 +807,18 @@ class AttentionLayers(nn.Module): self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None assert not ( - alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' + alibi_pos_bias and rel_pos_bias + ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" if alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) - assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' + assert alibi_num_heads <= heads, "number of ALiBi heads must be less than the total number of heads" alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) else: self.rel_pos = None - assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' + assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm" 
self.pre_norm = pre_norm self.sandwich_norm = sandwich_norm @@ -809,27 +835,30 @@ class AttentionLayers(nn.Module): branch_fn = Rezero if use_rezero else None if cross_attend and not only_cross: - default_block = ('a', 'c', 'f') + default_block = ("a", "c", "f") elif cross_attend and only_cross: - default_block = ('c', 'f') + default_block = ("c", "f") else: - default_block = ('a', 'f') + default_block = ("a", "f") if macaron: - default_block = ('f',) + default_block + default_block = ("f",) + default_block # qk normalization if use_qk_norm_attn: - attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists( - qk_norm_attn_seq_len) else None - attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value} + attn_scale_init_value = ( + -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len)) + if exists(qk_norm_attn_seq_len) + else None + ) + attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value} # zero init if zero_init_branch_output: - attn_kwargs = {**attn_kwargs, 'zero_init_output': True} - ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + attn_kwargs = {**attn_kwargs, "zero_init_output": True} + ff_kwargs = {**ff_kwargs, "zero_init_output": True} # calculate layer block order @@ -837,23 +866,23 @@ class AttentionLayers(nn.Module): layer_types = custom_layers elif exists(par_ratio): par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, 'par ratio out of range' - default_block = tuple(filter(not_equals('f'), default_block)) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) par_attn = par_depth // par_ratio depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert len(default_block) <= par_width, "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - len(par_head)) + layer_types = par_head + ("f",) * (par_depth - len(par_head)) elif exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth" + layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef else: layer_types = default_block * depth self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) # calculate token shifting @@ -864,15 +893,15 @@ class AttentionLayers(nn.Module): for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): is_last_layer = ind == (len(self.layer_types) - 1) - if layer_type == 'a': + if layer_type == "a": layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) - elif layer_type == 'c': + elif layer_type == "c": layer = Attention(dim, heads=heads, **attn_kwargs) - elif layer_type == 'f': + elif layer_type == "f": layer = 
FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) else: - raise Exception(f'invalid layer type {layer_type}') + raise Exception(f"invalid layer type {layer_type}") if layer_shift_tokens > 0: shift_range_upper = layer_shift_tokens + 1 @@ -885,42 +914,35 @@ class AttentionLayers(nn.Module): residual_fn = GRUGating if gate_residual else Residual residual = residual_fn(dim, scale_residual=scale_residual) - layer_uses_qk_norm = use_qk_norm_attn and layer_type in ('a', 'c') + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c") pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None - norms = nn.ModuleList([ - pre_branch_norm, - post_branch_norm, - post_main_norm - ]) + norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm]) - self.layers.append(nn.ModuleList([ - norms, - layer, - residual - ])) + self.layers.append(nn.ModuleList([norms, layer, residual])) def forward( - self, - x, - context=None, - full_context=None, # for passing a list of hidden states from an encoder - mask=None, - context_mask=None, - attn_mask=None, - mems=None, - return_hiddens=False, - norm_scale_shift_inp=None, - past_key_values=None, - expected_seq_len=None, + self, + x, + context=None, + full_context=None, # for passing a list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, + past_key_values=None, + expected_seq_len=None, ): - assert not (self.cross_attend ^ (exists(context) or exists( - full_context))), 'context must be passed in if cross_attend is set to True' - assert context is None or full_context is None, 'only one of full_context or context can be provided' + assert not ( + self.cross_attend ^ (exists(context) or exists(full_context)) + ), "context must be passed in if cross_attend is set to True" + assert context is None or full_context is None, "only one of full_context or context can be provided" hiddens = [] intermediates = [] @@ -930,24 +952,28 @@ class AttentionLayers(nn.Module): mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers norm_args = {} if exists(norm_scale_shift_inp): - norm_args['norm_scale_shift_inp'] = norm_scale_shift_inp + norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp rotary_pos_emb = None if exists(self.rotary_pos_emb): if not self.training and self.causal: - assert expected_seq_len is not None, "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" + assert ( + expected_seq_len is not None + ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" elif expected_seq_len is None: expected_seq_len = 0 seq_len = x.shape[1] if past_key_values is not None: seq_len += past_key_values[0][0].shape[-2] - max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len]) + max_rotary_emb_length = max( + list(map(lambda m: (m.shape[1] if exists(m) else 0) + seq_len, mems)) + [expected_seq_len] + ) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) present_key_values = [] cross_attn_count = 0 for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): - if layer_type == 'a': + if layer_type == "a": layer_mem = mems.pop(0) if mems else None residual = x @@ 
-957,26 +983,39 @@ class AttentionLayers(nn.Module): if exists(pre_branch_norm): x = pre_branch_norm(x, **norm_args) - if layer_type == 'a' or layer_type == 'c': + if layer_type == "a" or layer_type == "c": if past_key_values is not None: layer_kv = past_key_values.pop(0) layer_past = tuple(s.to(x.device) for s in layer_kv) else: layer_past = None - if layer_type == 'a': - out, inter, k, v = block(x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, - prev_attn, layer_mem, layer_past) - elif layer_type == 'c': + if layer_type == "a": + out, inter, k, v = block( + x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past + ) + elif layer_type == "c": if exists(full_context): - out, inter, k, v = block(x, full_context[cross_attn_count], mask, context_mask, None, None, - None, prev_attn, None, layer_past) + out, inter, k, v = block( + x, + full_context[cross_attn_count], + mask, + context_mask, + None, + None, + None, + prev_attn, + None, + layer_past, + ) else: - out, inter, k, v = block(x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past) - elif layer_type == 'f': + out, inter, k, v = block( + x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past + ) + elif layer_type == "f": out = block(x) - if layer_type == 'a' or layer_type == 'c' and present_key_values is not None: + if layer_type == "a" or layer_type == "c" and present_key_values is not None: present_key_values.append((k.detach(), v.detach())) if exists(post_branch_norm): @@ -984,28 +1023,26 @@ class AttentionLayers(nn.Module): x = residual_fn(out, residual) - if layer_type in ('a', 'c'): + if layer_type in ("a", "c"): intermediates.append(inter) - if layer_type == 'a' and self.residual_attn: + if layer_type == "a" and self.residual_attn: prev_attn = inter.pre_softmax_attn - elif layer_type == 'c' and self.cross_residual_attn: + elif layer_type == "c" and self.cross_residual_attn: prev_cross_attn = inter.pre_softmax_attn if exists(post_main_norm): x = post_main_norm(x, **norm_args) - if layer_type == 'c': + if layer_type == "c": cross_attn_count += 1 - if layer_type == 'f': + if layer_type == "f": hiddens.append(x) if return_hiddens: intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates, - past_key_values=present_key_values + hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values ) return x, intermediates @@ -1015,13 +1052,13 @@ class AttentionLayers(nn.Module): class Encoder(AttentionLayers): def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on encoder' + assert "causal" not in kwargs, "cannot set causality on encoder" super().__init__(causal=False, **kwargs) class Decoder(AttentionLayers): def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on decoder' + assert "causal" not in kwargs, "cannot set causality on decoder" super().__init__(causal=True, **kwargs) @@ -1031,22 +1068,13 @@ class CrossAttender(AttentionLayers): class ViTransformerWrapper(nn.Module): - def __init__( - self, - *, - image_size, - patch_size, - attn_layers, - num_classes=None, - dropout=0., - emb_dropout=0. 
- ): + def __init__(self, *, image_size, patch_size, attn_layers, num_classes=None, dropout=0.0, emb_dropout=0.0): super().__init__() - assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' - assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder" + assert image_size % patch_size == 0, "image dimensions must be divisible by the patch size" dim = attn_layers.dim num_patches = (image_size // patch_size) ** 2 - patch_dim = 3 * patch_size ** 2 + patch_dim = 3 * patch_size**2 self.patch_size = patch_size @@ -1059,20 +1087,16 @@ class ViTransformerWrapper(nn.Module): self.norm = nn.LayerNorm(dim) self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None - def forward( - self, - img, - return_embeddings=False - ): + def forward(self, img, return_embeddings=False): p = self.patch_size - x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) + x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) x = self.patch_to_embedding(x) b, n, _ = x.shape - cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b) + cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embedding[:, :(n + 1)] + x = x + self.pos_embedding[:, : (n + 1)] x = self.dropout(x) x = self.attn_layers(x) @@ -1086,21 +1110,21 @@ class ViTransformerWrapper(nn.Module): class TransformerWrapper(nn.Module): def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - shift_mem_down=0, - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + shift_mem_down=0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -1110,8 +1134,11 @@ class TransformerWrapper(nn.Module): self.shift_mem_down = shift_mem_down self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) self.emb_dropout = nn.Dropout(emb_dropout) self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() @@ -1132,15 +1159,15 @@ class TransformerWrapper(nn.Module): nn.init.kaiming_normal_(self.token_emb.weight) def forward( - self, - x, - return_embeddings=False, - mask=None, - return_hiddens=False, - return_attn=False, - mems=None, - use_cache=False, - **kwargs + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + use_cache=False, + **kwargs, ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens x = self.token_emb(x) @@ -1150,7 +1177,7 @@ class TransformerWrapper(nn.Module): x = self.project_emb(x) if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) x = torch.cat((mem, x), dim=1) # 
auto-handle masking after appending memory tokens @@ -1158,7 +1185,7 @@ class TransformerWrapper(nn.Module): mask = F.pad(mask, (num_mem, 0), value=True) if self.shift_mem_down and exists(mems): - mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] + mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :] mems = [*mems_r, *mems_l] x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) @@ -1186,25 +1213,20 @@ class TransformerWrapper(nn.Module): class ContinuousTransformerWrapper(nn.Module): def __init__( - self, - *, - max_seq_len, - attn_layers, - dim_in=None, - dim_out=None, - emb_dim=None, - emb_dropout=0., - use_pos_emb=True + self, *, max_seq_len, attn_layers, dim_in=None, dim_out=None, emb_dim=None, emb_dropout=0.0, use_pos_emb=True ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" dim = attn_layers.dim self.max_seq_len = max_seq_len - self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.pos_emb = ( + AbsolutePositionalEmbedding(dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) self.emb_dropout = nn.Dropout(emb_dropout) self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() @@ -1214,16 +1236,7 @@ class ContinuousTransformerWrapper(nn.Module): self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_attn=False, - mems=None, - use_cache=False, - **kwargs - ): + def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs): b, n, _, device = *x.shape, x.device x = self.project_in(x) @@ -1245,4 +1258,3 @@ class ContinuousTransformerWrapper(nn.Module): if len(res) > 1: return tuple(res) return res[0] - diff --git a/TTS/tts/layers/vits/fram_prior_network.py b/TTS/tts/layers/vits/fram_prior_network.py deleted file mode 100644 index c393e3fc..00000000 --- a/TTS/tts/layers/vits/fram_prior_network.py +++ /dev/null @@ -1,31 +0,0 @@ -from torch import nn -from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock - - -class FramePriorNet(nn.Module): - def __init__( - self, in_channels, out_channels, hidden_channels, kernel_size, num_res_blocks=13, num_conv_blocks=2 - ): - super().__init__() - self.res_blocks = nn.ModuleList() - for idx in range(num_res_blocks): - block = Conv1dBNBlock( - in_channels if idx == 0 else hidden_channels, - out_channels if (idx + 1) == num_res_blocks else hidden_channels, - hidden_channels, - kernel_size, - 1, - num_conv_blocks, - ) - self.res_blocks.append(block) - def forward(self, x, x_mask=None): - if x_mask is None: - x_mask = 1.0 - o = x * x_mask - for block in self.res_blocks: - res = o - o = block(o) - o = o + res - if x_mask is not None: - o = o * x_mask - return o \ No newline at end of file diff --git a/TTS/tts/layers/vits/reference_encoder.py b/TTS/tts/layers/vits/reference_encoder.py deleted file mode 100644 index d62ff0ca..00000000 --- a/TTS/tts/layers/vits/reference_encoder.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F - -class ReferenceEncoder(nn.Module): - """NN module creating a fixed size prosody embedding from a spectrogram. 
- - inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] - outputs: [batch_size, embedding_dim] - """ - - def __init__(self, num_mel, filter): - super().__init__() - self.num_mel = num_mel - start_index = 2 - end_index = filter / 16 - i = start_index - filt_len = [] - while i <= end_index: - i = i * 2 - filt_len.append(i) - filt_len.append(i) - filters = [1] + filt_len - num_layers = len(filters) - 1 - convs = [ - nn.Conv2d( - in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2) - ) - for i in range(num_layers) - ] - self.convs = nn.ModuleList(convs) - self.training = False - self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) - - post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) - self.recurrence = nn.LSTM( - input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False - ) - - def forward(self, inputs, input_lengths): - batch_size = inputs.size(0) - x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel] - valid_lengths = input_lengths.float() # [batch_size] - for conv, bn in zip(self.convs, self.bns): - x = conv(x) - x = bn(x) - x = F.relu(x) - - # Create the post conv width mask based on the valid lengths of the output of the convolution. - # The valid lengths for the output of a convolution on varying length inputs is - # ceil(input_length/stride) + 1 for stride=3 and padding=2 - # For example (kernel_size=3, stride=2, padding=2): - # 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d - # _____ - # x _____ - # x _____ - # x ____ - # x - # x x x x -> Output valid length = 4 - # Since every example in te batch is zero padded and therefore have separate valid_lengths, - # we need to mask off all the values AFTER the valid length for each example in the batch. 
- # Otherwise, the convolutions create noise and a lot of not real information - valid_lengths = (valid_lengths / 2).float() - valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size] - post_conv_max_width = x.size(2) - - mask = torch.arange(post_conv_max_width).to(inputs.device).expand( - len(valid_lengths), post_conv_max_width - ) < valid_lengths.unsqueeze(1) - mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1] - x = x * mask - - x = x.transpose(1, 2) - # x: 4D tensor [batch_size, post_conv_width, - # num_channels==128, post_conv_height] - - post_conv_width = x.size(1) - x = x.contiguous().view(batch_size, post_conv_width, -1) - # x: 3D tensor [batch_size, post_conv_width, - # num_channels*post_conv_height] - - # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding - post_conv_input_lengths = valid_lengths - packed_seqs = nn.utils.rnn.pack_padded_sequence( - x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False - ) # dynamic rnn sequence padding - self.recurrence.flatten_parameters() - _, (ht, _) = self.recurrence(packed_seqs) - last_output = ht[-1] - - return last_output.to(inputs.device) # [B, 128] \ No newline at end of file diff --git a/TTS/tts/layers/vits/vqvae.py b/TTS/tts/layers/vits/vqvae.py deleted file mode 100644 index 7ebd8692..00000000 --- a/TTS/tts/layers/vits/vqvae.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -from torch.autograd import Function - -class VectorQuantization(Function): - @staticmethod - def forward(ctx, inputs, codebook): - with torch.no_grad(): - embedding_size = codebook.size(1) - inputs_size = inputs.size() - inputs_flatten = inputs.view(-1, embedding_size) - - codebook_sqr = torch.sum(codebook ** 2, dim=1) - inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) - - # Compute the distances to the codebook - distances = torch.addmm(codebook_sqr + inputs_sqr, - inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0) - - _, indices_flatten = torch.min(distances, dim=1) - indices = indices_flatten.view(*inputs_size[:-1]) - ctx.mark_non_differentiable(indices) - - return indices - - @staticmethod - def backward(ctx, grad_output): - raise RuntimeError('Trying to call `.grad()` on graph containing ' - '`VectorQuantization`. The function `VectorQuantization` ' - 'is not differentiable. Use `VectorQuantizationStraightThrough` ' - 'if you want a straight-through estimator of the gradient.') - -class VectorQuantizationStraightThrough(Function): - @staticmethod - def forward(ctx, inputs, codebook): - indices = vq(inputs, codebook) - indices_flatten = indices.view(-1) - ctx.save_for_backward(indices_flatten, codebook) - ctx.mark_non_differentiable(indices_flatten) - - codes_flatten = torch.index_select(codebook, dim=0, - index=indices_flatten) - codes = codes_flatten.view_as(inputs) - - return (codes, indices_flatten) - - @staticmethod - def backward(ctx, grad_output, grad_indices): - grad_inputs, grad_codebook = None, None - - if ctx.needs_input_grad[0]: - # Straight-through estimator - grad_inputs = grad_output.clone() - if ctx.needs_input_grad[1]: - # Gradient wrt. 
the codebook - indices, codebook = ctx.saved_tensors - embedding_size = codebook.size(1) - - grad_output_flatten = (grad_output.contiguous() - .view(-1, embedding_size)) - grad_codebook = torch.zeros_like(codebook) - grad_codebook.index_add_(0, indices, grad_output_flatten) - - return (grad_inputs, grad_codebook) - -vq = VectorQuantization.apply -vq_st = VectorQuantizationStraightThrough.apply -__all__ = [vq, vq_st] diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py deleted file mode 100644 index f38dace2..00000000 --- a/TTS/tts/models/base_tacotron.py +++ /dev/null @@ -1,305 +0,0 @@ -import copy -from abc import abstractmethod -from typing import Dict, Tuple - -import torch -from coqpit import Coqpit -from torch import nn - -from TTS.tts.layers.losses import TacotronLoss -from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import sequence_mask -from TTS.tts.utils.speakers import SpeakerManager -from TTS.tts.utils.synthesis import synthesis -from TTS.tts.utils.text.tokenizer import TTSTokenizer -from TTS.tts.utils.visual import plot_alignment, plot_spectrogram -from TTS.utils.generic_utils import format_aux_input -from TTS.utils.io import load_fsspec -from TTS.utils.training import gradual_training_scheduler - - -class BaseTacotron(BaseTTS): - """Base class shared by Tacotron and Tacotron2""" - - def __init__( - self, - config: "TacotronConfig", - ap: "AudioProcessor", - tokenizer: "TTSTokenizer", - speaker_manager: SpeakerManager = None, - ): - super().__init__(config, ap, tokenizer, speaker_manager) - - # pass all config fields as class attributes - for key in config: - setattr(self, key, config[key]) - - # layers - self.embedding = None - self.encoder = None - self.decoder = None - self.postnet = None - - # init tensors - self.embedded_speakers = None - self.embedded_speakers_projected = None - - # global style token - if self.gst and self.use_gst: - self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim - self.gst_layer = None - - # Capacitron - if self.capacitron_vae and self.use_capacitron_vae: - self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim - self.capacitron_vae_layer = None - - # additional layers - self.decoder_backward = None - self.coarse_decoder = None - - @staticmethod - def _format_aux_input(aux_input: Dict) -> Dict: - """Set missing fields to their default values""" - if aux_input: - return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) - return None - - ############################# - # INIT FUNCTIONS - ############################# - - def _init_backward_decoder(self): - """Init the backward decoder for Forward-Backward decoding.""" - self.decoder_backward = copy.deepcopy(self.decoder) - - def _init_coarse_decoder(self): - """Init the coarse decoder for Double-Decoder Consistency.""" - self.coarse_decoder = copy.deepcopy(self.decoder) - self.coarse_decoder.r_init = self.ddc_r - self.coarse_decoder.set_r(self.ddc_r) - - ############################# - # CORE FUNCTIONS - ############################# - - @abstractmethod - def forward(self): - pass - - @abstractmethod - def inference(self): - pass - - def load_checkpoint( - self, config, checkpoint_path, eval=False, cache=False - ): # pylint: disable=unused-argument, redefined-builtin - """Load model checkpoint and set up internals. - - Args: - config (Coqpi): model configuration. - checkpoint_path (str): path to checkpoint file. 
- eval (bool, optional): whether to load model for evaluation. - cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. - """ - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) - self.load_state_dict(state["model"]) - # TODO: set r in run-time by taking it from the new config - if "r" in state: - # set r from the state (for compatibility with older checkpoints) - self.decoder.set_r(state["r"]) - elif "config" in state: - # set r from config used at training time (for inference) - self.decoder.set_r(state["config"]["r"]) - else: - # set r from the new config (for new-models) - self.decoder.set_r(config.r) - if eval: - self.eval() - print(f" > Model's reduction rate `r` is set to: {self.decoder.r}") - assert not self.training - - def get_criterion(self) -> nn.Module: - """Get the model criterion used in training.""" - return TacotronLoss(self.config) - - @staticmethod - def init_from_config(config: Coqpit): - """Initialize model from config.""" - from TTS.utils.audio import AudioProcessor - - ap = AudioProcessor.init_from_config(config) - tokenizer = TTSTokenizer.init_from_config(config) - speaker_manager = SpeakerManager.init_from_config(config) - return BaseTacotron(config, ap, tokenizer, speaker_manager) - - ########################## - # TEST AND LOG FUNCTIONS # - ########################## - - def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: - """Generic test run for `tts` models used by `Trainer`. - - You can override this for a different behaviour. - - Args: - assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. - - Returns: - Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. 
- """ - print(" | > Synthesizing test sentences.") - test_audios = {} - test_figures = {} - test_sentences = self.config.test_sentences - aux_inputs = self._get_test_aux_input() - for idx, sen in enumerate(test_sentences): - outputs_dict = synthesis( - self, - sen, - self.config, - "cuda" in str(next(self.parameters()).device), - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - use_griffin_lim=True, - do_trim_silence=False, - ) - test_audios["{}-audio".format(idx)] = outputs_dict["wav"] - test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False - ) - test_figures["{}-alignment".format(idx)] = plot_alignment( - outputs_dict["outputs"]["alignments"], output_fig=False - ) - return {"figures": test_figures, "audios": test_audios} - - def test_log( - self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument - ) -> None: - logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) - logger.test_figures(steps, outputs["figures"]) - - ############################# - # COMMON COMPUTE FUNCTIONS - ############################# - - def compute_masks(self, text_lengths, mel_lengths): - """Compute masks against sequence paddings.""" - # B x T_in_max (boolean) - input_mask = sequence_mask(text_lengths) - output_mask = None - if mel_lengths is not None: - max_len = mel_lengths.max() - r = self.decoder.r - max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len - output_mask = sequence_mask(mel_lengths, max_len=max_len) - return input_mask, output_mask - - def _backward_pass(self, mel_specs, encoder_outputs, mask): - """Run backwards decoder""" - decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask - ) - decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() - return decoder_outputs_b, alignments_b - - def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): - """Double Decoder Consistency""" - T = mel_specs.shape[1] - if T % self.coarse_decoder.r > 0: - padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) - mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) - decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( - encoder_outputs.detach(), mel_specs, input_mask - ) - # scale_factor = self.decoder.r_init / self.decoder.r - alignments_backward = torch.nn.functional.interpolate( - alignments_backward.transpose(1, 2), - size=alignments.shape[1], - mode="nearest", - ).transpose(1, 2) - decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) - decoder_outputs_backward = decoder_outputs_backward[:, :T, :] - return decoder_outputs_backward, alignments_backward - - ############################# - # EMBEDDING FUNCTIONS - ############################# - - def compute_gst(self, inputs, style_input, speaker_embedding=None): - """Compute global style token""" - if isinstance(style_input, dict): - # multiply each style token with a weight - query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) - if speaker_embedding is not None: - query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) - - _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) - for k_token, v_amplifier in style_input.items(): - key = 
_GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) - gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) - gst_outputs = gst_outputs + gst_outputs_att * v_amplifier - elif style_input is None: - # ignore style token and return zero tensor - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) - else: - # compute style tokens - gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable - inputs = self._concat_speaker_embedding(inputs, gst_outputs) - return inputs - - def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): - """Capacitron Variational Autoencoder""" - ( - VAE_outputs, - posterior_distribution, - prior_distribution, - capacitron_beta, - ) = self.capacitron_vae_layer( - reference_mel_info, - text_info, - speaker_embedding, # pylint: disable=not-callable - ) - - VAE_outputs = VAE_outputs.to(inputs.device) - encoder_output = self._concat_speaker_embedding( - inputs, VAE_outputs - ) # concatenate to the output of the basic tacotron encoder - return ( - encoder_output, - posterior_distribution, - prior_distribution, - capacitron_beta, - ) - - @staticmethod - def _add_speaker_embedding(outputs, embedded_speakers): - embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) - outputs = outputs + embedded_speakers_ - return outputs - - @staticmethod - def _concat_speaker_embedding(outputs, embedded_speakers): - embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) - outputs = torch.cat([outputs, embedded_speakers_], dim=-1) - return outputs - - ############################# - # CALLBACKS - ############################# - - def on_epoch_start(self, trainer): - """Callback for setting values wrt gradual training schedule. - - Args: - trainer (TrainerTTS): TTS trainer object that is used to train this model. 
- """ - if self.gradual_training: - r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) - trainer.config.r = r - self.decoder.set_r(r) - if trainer.config.bidirectional_decoder: - trainer.model.decoder_backward.set_r(r) - print(f"\n > Number of output frames: {self.decoder.r}") diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index 5dd51415..dbd694aa 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -2,12 +2,12 @@ import os import random +from contextlib import contextmanager from time import time import torch import torch.nn.functional as F import torchaudio - from tqdm import tqdm from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram @@ -16,22 +16,14 @@ from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead from TTS.tts.layers.tortoise.clvp import CLVP from TTS.tts.layers.tortoise.cvvp import CVVP +from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter -from TTS.tts.layers.tortoise.vocoder import VocConf - -from TTS.tts.layers.tortoise.diffusion import ( - SpacedDiffusion, - get_named_beta_schedule, - space_timesteps, -) - from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path +from TTS.tts.layers.tortoise.vocoder import VocConf from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment -from TTS.tts.layers.tortoise.utils import MODELS_DIR, get_model_path - -from contextlib import contextmanager def pad_or_truncate(t, length): """ @@ -56,9 +48,7 @@ def load_discrete_vocoder_diffuser( Helper function to load a GaussianDiffusion instance configured for use as a vocoder. 
""" return SpacedDiffusion( - use_timesteps=space_timesteps( - trained_diffusion_steps, [desired_diffusion_steps] - ), + use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), model_mean_type="epsilon", model_var_type="learned_range", loss_type="mse", @@ -137,12 +127,12 @@ def do_spectrogram_diffusion( noise = torch.randn(output_shape, device=latents.device) * temperature mel = diffuser.sample_loop( - diffusion_model, - output_shape, - noise=noise, - model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, - progress=verbose - ) + diffusion_model, + output_shape, + noise=noise, + model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, + progress=verbose, + ) return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] @@ -166,9 +156,7 @@ def classify_audio_clip(clip): kernel_size=5, distribute_zero_label=False, ) - classifier.load_state_dict( - torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu")) - ) + classifier.load_state_dict(torch.load(get_model_path("classifier.pth"), map_location=torch.device("cpu"))) clip = clip.cpu().unsqueeze(0) results = F.softmax(classifier(clip), dim=-1) return results[0][0] @@ -238,9 +226,7 @@ class TextToSpeech: self.diff_checkpoint = diff_checkpoint # TODO: check if this is even needed self.models_dir = models_dir self.autoregressive_batch_size = ( - pick_best_batch_size_for_gpu() - if autoregressive_batch_size is None - else autoregressive_batch_size + pick_best_batch_size_for_gpu() if autoregressive_batch_size is None else autoregressive_batch_size ) self.enable_redaction = enable_redaction self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -274,9 +260,7 @@ class TextToSpeech: self.autoregressive.load_state_dict(torch.load(ar_path)) self.autoregressive.post_init_gpt2_config(kv_cache) - diff_path = diff_checkpoint or get_model_path( - "diffusion_decoder.pth", models_dir - ) + diff_path = diff_checkpoint or get_model_path("diffusion_decoder.pth", models_dir) self.diffusion = ( DiffusionTts( model_channels=1024, @@ -365,9 +349,7 @@ class TextToSpeech: .cpu() .eval() ) - self.cvvp.load_state_dict( - torch.load(get_model_path("cvvp.pth", self.models_dir)) - ) + self.cvvp.load_state_dict(torch.load(get_model_path("cvvp.pth", self.models_dir))) def get_conditioning_latents( self, @@ -407,11 +389,7 @@ class TextToSpeech: DURS_CONST = 102400 for ls in voice_samples: # The diffuser operates at a sample rate of 24000 (except for the latent inputs) - sample = ( - torchaudio.functional.resample(ls[0], 22050, 24000) - if original_tortoise - else ls[1] - ) + sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1] if latent_averaging_mode == 0: sample = pad_or_truncate(sample, DURS_CONST) cond_mel = wav_to_univnet_mel( @@ -426,9 +404,7 @@ class TextToSpeech: if latent_averaging_mode == 2: temp_diffusion_conds = [] for chunk in range(ceil(sample.shape[1] / DURS_CONST)): - current_sample = sample[ - :, chunk * DURS_CONST : (chunk + 1) * DURS_CONST - ] + current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST] current_sample = pad_or_truncate(current_sample, DURS_CONST) cond_mel = wav_to_univnet_mel( current_sample.to(self.device), @@ -440,9 +416,7 @@ class TextToSpeech: elif latent_averaging_mode == 2: temp_diffusion_conds.append(cond_mel) if latent_averaging_mode == 2: - diffusion_conds.append( - torch.stack(temp_diffusion_conds).mean(0) - ) + diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0)) 
diffusion_conds = torch.stack(diffusion_conds, dim=1) with self.temporary_cuda(self.diffusion) as diffusion: @@ -471,9 +445,7 @@ class TextToSpeech: ) ) with torch.no_grad(): - return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion( - torch.tensor([0.0]) - ) + return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) def tts_with_preset(self, text, preset="fast", **kwargs): """ @@ -521,10 +493,7 @@ class TextToSpeech: "diffusion_iterations": 50, "sampler": "ddim", }, - "fast_old": { - "num_autoregressive_samples": 96, - "diffusion_iterations": 80 - }, + "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80}, "standard": { "num_autoregressive_samples": 256, "diffusion_iterations": 200, @@ -618,9 +587,7 @@ class TextToSpeech: """ deterministic_seed = self.deterministic_state(seed=use_deterministic_seed) - text_tokens = ( - torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) - ) + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. assert ( text_tokens.shape[-1] < 400 @@ -628,12 +595,7 @@ class TextToSpeech: auto_conds = None if voice_samples is not None: - ( - auto_conditioning, - diffusion_conditioning, - auto_conds, - _, - ) = self.get_conditioning_latents( + (auto_conditioning, diffusion_conditioning, auto_conds, _,) = self.get_conditioning_latents( voice_samples, return_mels=True, latent_averaging_mode=latent_averaging_mode, @@ -650,10 +612,7 @@ class TextToSpeech: diffusion_conditioning = diffusion_conditioning.to(self.device) diffuser = load_discrete_vocoder_diffuser( - desired_diffusion_steps=diffusion_iterations, - cond_free=cond_free, - cond_free_k=cond_free_k, - sampler=sampler + desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler ) # in the case of single_sample, @@ -664,13 +623,13 @@ class TextToSpeech: samples = [] num_batches = num_autoregressive_samples // self.autoregressive_batch_size stop_mel_token = self.autoregressive.stop_mel_token - calm_token = 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + calm_token = ( + 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + ) self.autoregressive = self.autoregressive.to(self.device) if verbose: print("Generating autoregressive samples..") - with self.temporary_cuda( - self.autoregressive - ) as autoregressive, torch.autocast( + with self.temporary_cuda(self.autoregressive) as autoregressive, torch.autocast( device_type="cuda", dtype=torch.float16, enabled=half ): for b in tqdm(range(num_batches), disable=not verbose): @@ -689,9 +648,7 @@ class TextToSpeech: padding_needed = max_mel_tokens - codes.shape[1] codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) samples.append(codes) - self.autoregressive_batch_size = ( - orig_batch_size # in the case of single_sample - ) + self.autoregressive_batch_size = orig_batch_size # in the case of single_sample clip_results = [] with self.temporary_cuda(self.clvp) as clvp, torch.autocast( @@ -729,9 +686,7 @@ class TextToSpeech: if cvvp_amount == 1: clip_results.append(cvvp) else: - clip_results.append( - cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount) - ) + clip_results.append(cvvp * cvvp_amount + clvp_res * (1 - cvvp_amount)) else: clip_results.append(clvp_res) clip_results = torch.cat(clip_results, dim=0) @@ -744,19 +699,14 @@ class TextToSpeech: # 
The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. - with self.temporary_cuda( - self.autoregressive - ) as autoregressive: + with self.temporary_cuda(self.autoregressive) as autoregressive: best_latents = autoregressive( auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, torch.tensor( - [ - best_results.shape[-1] - * self.autoregressive.mel_length_compression - ], + [best_results.shape[-1] * self.autoregressive.mel_length_compression], device=text_tokens.device, ), return_latent=True, @@ -778,9 +728,7 @@ class TextToSpeech: ctokens += 1 else: ctokens = 0 - if ( - ctokens > 8 - ): # 8 tokens gives the diffusion model some "breathing room" to terminate speech. + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. latents = latents[:, :k] break with self.temporary_cuda(self.diffusion) as diffusion: @@ -801,10 +749,7 @@ class TextToSpeech: return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) return clip - wav_candidates = [ - potentially_redact(wav_candidate, text) - for wav_candidate in wav_candidates - ] + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] if len(wav_candidates) > 1: res = wav_candidates diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index dffb5a84..fd40ebb0 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -97,7 +97,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method self.mel_norm = mel_norm self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.mel_basis = None - self.normalized=normalized + self.normalized = normalized if use_mel: self._build_mel_basis()
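For context on the `TorchSTFT` change above: the module computes a magnitude STFT (optionally normalized) and, when `use_mel` is set, projects it onto a precomputed mel basis. The snippet below is a minimal, generic sketch of that pipeline using `torch.stft` and a torchaudio mel filterbank; the function name and parameter defaults are placeholders chosen for illustration, not the project's actual implementation.

import torch
import torchaudio


def mel_spectrogram(wav, n_fft=1024, hop=256, win=1024, n_mels=80, sr=22050, normalized=False):
    # wav: [B, T] waveform. Magnitude STFT, then projection onto a mel filterbank.
    window = torch.hann_window(win, device=wav.device)
    spec = torch.stft(
        wav, n_fft, hop_length=hop, win_length=win, window=window,
        normalized=normalized, return_complex=True,
    ).abs()  # [B, n_fft // 2 + 1, frames]
    fbanks = torchaudio.functional.melscale_fbanks(
        n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=sr / 2, n_mels=n_mels, sample_rate=sr
    )  # [n_freqs, n_mels]
    return torch.matmul(fbanks.T.to(spec.device), spec)  # [B, n_mels, frames]


# Example: mel = mel_spectrogram(torch.randn(1, 22050))  -> shape [1, 80, frames]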