refactor: move exists() and default() into generic_utils

This commit is contained in:
Enno Hermann 2024-11-22 22:02:26 +01:00
parent fa844e0fb7
commit b45a7a4220
7 changed files with 27 additions and 56 deletions

View File

@ -14,6 +14,8 @@ from torch import nn
from torchaudio.functional import resample from torchaudio.functional import resample
from transformers import HubertModel from transformers import HubertModel
from TTS.utils.generic_utils import exists
def round_down_nearest_multiple(num, divisor): def round_down_nearest_multiple(num, divisor):
return num // divisor * divisor return num // divisor * divisor
@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False):
return t[..., seq_slice] return t[..., seq_slice]
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
class CustomHubert(nn.Module): class CustomHubert(nn.Module):
""" """
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

View File

@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer
from TTS.tts.layers.tortoise.xtransformers import Encoder from TTS.tts.layers.tortoise.xtransformers import Encoder
def exists(val):
return val is not None
def masked_mean(t, mask, dim=1): def masked_mean(t, mask, dim=1):
t = t.masked_fill(~mask[:, :, None], 0.0) t = t.masked_fill(~mask[:, :, None], 0.0)
return t.sum(dim=1) / mask.sum(dim=1)[..., None] return t.sum(dim=1) / mask.sum(dim=1)[..., None]

View File

@ -1,22 +1,19 @@
from typing import TypeVar, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from TTS.utils.generic_utils import exists
# helpers # helpers
_T = TypeVar("_T")
def exists(val): def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
return val is not None
def default(val, d):
return val if exists(val) else d
def cast_tuple(val, depth=1):
if isinstance(val, list): if isinstance(val, list):
val = tuple(val) return tuple(val)
return val if isinstance(val, tuple) else (val,) * depth return val if isinstance(val, tuple) else (val,) * depth

View File

@ -1,13 +1,15 @@
import math import math
from collections import namedtuple from collections import namedtuple
from functools import partial from functools import partial
from inspect import isfunction
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import einsum, nn from torch import einsum, nn
from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value
from TTS.utils.generic_utils import default, exists
DEFAULT_DIM_HEAD = 64 DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
@ -25,20 +27,6 @@ LayerIntermediates = namedtuple(
# helpers # helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
class always: class always:
def __init__(self, val): def __init__(self, val):
self.val = val self.val = val
@ -63,10 +51,6 @@ class equals:
return x == self.val return x == self.val
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
def l2norm(t): def l2norm(t):
return F.normalize(t, p=2, dim=-1) return F.normalize(t, p=2, dim=-1)

View File

@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def default(val, d):
return val if val is not None else d
def eval_decorator(fn): def eval_decorator(fn):
def inner(model, *args, **kwargs): def inner(model, *args, **kwargs):
was_training = model.training was_training = model.training

View File

@ -10,10 +10,7 @@ from einops.layers.torch import Rearrange
from torch import einsum, nn from torch import einsum, nn
from TTS.tts.layers.tortoise.transformer import GEGLU from TTS.tts.layers.tortoise.transformer import GEGLU
from TTS.utils.generic_utils import default, exists
def exists(val):
return val is not None
def once(fn): def once(fn):
@ -153,12 +150,6 @@ def Sequential(*mods):
return nn.Sequential(*filter(exists, mods)) return nn.Sequential(*filter(exists, mods))
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim, scale=True, dim_cond=None): def __init__(self, dim, scale=True, dim_cond=None):
super().__init__() super().__init__()

View File

@ -4,13 +4,26 @@ import importlib
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Callable, Dict, Optional, TypeVar, Union
import torch import torch
from packaging.version import Version from packaging.version import Version
from typing_extensions import TypeIs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_T = TypeVar("_T")
def exists(val: Union[_T, None]) -> TypeIs[_T]:
return val is not None
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
if exists(val):
return val
return d() if callable(d) else d
def to_camel(text): def to_camel(text):
text = text.capitalize() text = text.capitalize()