refactor: move exists() and default() into generic_utils

Enno Hermann 2024-11-22 22:02:26 +01:00
parent fa844e0fb7
commit b45a7a4220
7 changed files with 27 additions and 56 deletions
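The duplicated helpers previously lived in each module; after this change every call site imports them from TTS.utils.generic_utils. An illustrative, hypothetical call site (attention_dims is made up for the example):

from TTS.utils.generic_utils import default, exists

def attention_dims(dim, dim_head=None, heads=8):
    # `default` returns its first argument unless it is None, otherwise the fallback
    # (a plain value or a zero-argument callable).
    dim_head = default(dim_head, dim // heads)
    # `exists` is the shared replacement for the ad-hoc `val is not None` checks.
    assert exists(dim_head)
    return dim_head * heads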


@@ -14,6 +14,8 @@ from torch import nn
from torchaudio.functional import resample
from transformers import HubertModel
from TTS.utils.generic_utils import exists
def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor
@@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False):
    return t[..., seq_slice]
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
class CustomHubert(nn.Module):
    """
    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert


@@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer
from TTS.tts.layers.tortoise.xtransformers import Encoder
def exists(val):
    return val is not None
def masked_mean(t, mask, dim=1):
    t = t.masked_fill(~mask[:, :, None], 0.0)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


@@ -1,22 +1,19 @@
from typing import TypeVar, Union
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from TTS.utils.generic_utils import exists
# helpers
_T = TypeVar("_T")
def exists(val):
    return val is not None
def default(val, d):
    return val if exists(val) else d
def cast_tuple(val, depth=1):
def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
    if isinstance(val, list):
        val = tuple(val)
        return tuple(val)
    return val if isinstance(val, tuple) else (val,) * depth
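The retyped cast_tuple keeps its original behavior; a quick illustrative check (the concrete values are arbitrary):

from TTS.tts.layers.tortoise.transformer import cast_tuple

assert cast_tuple((1, 2)) == (1, 2)      # tuples pass through unchanged
assert cast_tuple([1, 2]) == (1, 2)      # lists are converted to tuples
assert cast_tuple(3, depth=2) == (3, 3)  # scalars are repeated `depth` times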


@@ -1,13 +1,15 @@
import math
from collections import namedtuple
from functools import partial
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value
from TTS.utils.generic_utils import default, exists
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
@@ -25,20 +27,6 @@ LayerIntermediates = namedtuple(
# helpers
def exists(val):
    return val is not None
def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
def cast_tuple(val, depth):
    return val if isinstance(val, tuple) else (val,) * depth
class always:
    def __init__(self, val):
        self.val = val
@@ -63,10 +51,6 @@ class equals:
        return x == self.val
def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max
def l2norm(t):
    return F.normalize(t, p=2, dim=-1)
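Note that the removed default in this module only invoked the fallback when it was a plain function (isfunction); the consolidated helper in generic_utils uses callable, so any zero-argument callable is evaluated lazily. A minimal sketch (expensive_fallback is a made-up name):

from TTS.utils.generic_utils import default

def expensive_fallback():
    # Only evaluated when the first argument is None.
    return 64

assert default(32, expensive_fallback) == 32
assert default(None, expensive_fallback) == 64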


@@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__)
def default(val, d):
    return val if val is not None else d
def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training


@@ -10,10 +10,7 @@ from einops.layers.torch import Rearrange
from torch import einsum, nn
from TTS.tts.layers.tortoise.transformer import GEGLU
def exists(val):
    return val is not None
from TTS.utils.generic_utils import default, exists
def once(fn):
@@ -153,12 +150,6 @@ def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))
def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d
class RMSNorm(nn.Module):
    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()


@@ -4,13 +4,26 @@ import importlib
import logging
import re
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, Optional, TypeVar, Union
import torch
from packaging.version import Version
from typing_extensions import TypeIs
logger = logging.getLogger(__name__)
_T = TypeVar("_T")
def exists(val: Union[_T, None]) -> TypeIs[_T]:
    return val is not None
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
    if exists(val):
        return val
    return d() if callable(d) else d
def to_camel(text):
    text = text.capitalize()
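The new annotations also help static checking: TypeIs[_T] (from typing_extensions) lets type checkers narrow an Optional value after an exists check, and default accepts either a plain fallback or a zero-argument callable. A short illustrative sketch (describe is a made-up function):

from typing import Optional

from TTS.utils.generic_utils import default, exists

def describe(name: Optional[str]) -> str:
    if exists(name):
        # With TypeIs, type checkers treat `name` as `str` in this branch.
        return name.upper()
    return default(name, "unnamed")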