From b45a7a4220e21eea4825d24ba4498afb37591c64 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 22:02:26 +0100 Subject: [PATCH] refactor: move exists() and default() into generic_utils --- TTS/tts/layers/bark/hubert/kmeans_hubert.py | 10 ++-------- TTS/tts/layers/tortoise/clvp.py | 4 ---- TTS/tts/layers/tortoise/transformer.py | 17 +++++++--------- TTS/tts/layers/tortoise/xtransformers.py | 22 +++------------------ TTS/tts/layers/xtts/dvae.py | 4 ---- TTS/tts/layers/xtts/perceiver_encoder.py | 11 +---------- TTS/utils/generic_utils.py | 15 +++++++++++++- 7 files changed, 27 insertions(+), 56 deletions(-) diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index 58a614cb..ade84794 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -14,6 +14,8 @@ from torch import nn from torchaudio.functional import resample from transformers import HubertModel +from TTS.utils.generic_utils import exists + def round_down_nearest_multiple(num, divisor): return num // divisor * divisor @@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False): return t[..., seq_slice] -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - class CustomHubert(nn.Module): """ checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py index 241dfdd4..44da1324 100644 --- a/TTS/tts/layers/tortoise/clvp.py +++ b/TTS/tts/layers/tortoise/clvp.py @@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer from TTS.tts.layers.tortoise.xtransformers import Encoder -def exists(val): - return val is not None - - def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.0) return t.sum(dim=1) / mask.sum(dim=1)[..., None] diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py index 6cb1bab9..ed4d79d4 100644 --- a/TTS/tts/layers/tortoise/transformer.py +++ b/TTS/tts/layers/tortoise/transformer.py @@ -1,22 +1,19 @@ +from typing import TypeVar, Union + import torch import torch.nn.functional as F from einops import rearrange from torch import nn +from TTS.utils.generic_utils import exists + # helpers +_T = TypeVar("_T") -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - -def cast_tuple(val, depth=1): +def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]: if isinstance(val, list): - val = tuple(val) + return tuple(val) return val if isinstance(val, tuple) else (val,) * depth diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py index 9325b8c7..0892fee1 100644 --- a/TTS/tts/layers/tortoise/xtransformers.py +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -1,13 +1,15 @@ import math from collections import namedtuple from functools import partial -from inspect import isfunction import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import einsum, nn +from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value +from TTS.utils.generic_utils import default, exists + DEFAULT_DIM_HEAD = 64 Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) @@ -25,20 +27,6 @@ LayerIntermediates = namedtuple( # helpers -def exists(val): - return val is not None - - -def 
default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def cast_tuple(val, depth): - return val if isinstance(val, tuple) else (val,) * depth - - class always: def __init__(self, val): self.val = val @@ -63,10 +51,6 @@ class equals: return x == self.val -def max_neg_value(tensor): - return -torch.finfo(tensor.dtype).max - - def l2norm(t): return F.normalize(t, p=2, dim=-1) diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index 73970fb0..4f806f82 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) -def default(val, d): - return val if val is not None else d - - def eval_decorator(fn): def inner(model, *args, **kwargs): was_training = model.training diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py index 4b42a0e4..74770872 100644 --- a/TTS/tts/layers/xtts/perceiver_encoder.py +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -10,10 +10,7 @@ from einops.layers.torch import Rearrange from torch import einsum, nn from TTS.tts.layers.tortoise.transformer import GEGLU - - -def exists(val): - return val is not None +from TTS.utils.generic_utils import default, exists def once(fn): @@ -153,12 +150,6 @@ def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) -def default(val, d): - if exists(val): - return val - return d() if callable(d) else d - - class RMSNorm(nn.Module): def __init__(self, dim, scale=True, dim_cond=None): super().__init__() diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index c3828224..087ae7d0 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -4,13 +4,26 @@ import importlib import logging import re from pathlib import Path -from typing import Dict, Optional +from typing import Callable, Dict, Optional, TypeVar, Union import torch from packaging.version import Version +from typing_extensions import TypeIs logger = logging.getLogger(__name__) +_T = TypeVar("_T") + + +def exists(val: Union[_T, None]) -> TypeIs[_T]: + return val is not None + + +def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T: + if exists(val): + return val + return d() if callable(d) else d + def to_camel(text): text = text.capitalize()
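
Note (reviewer sketch, not part of the patch): a minimal example of the consolidated helpers in use, assuming only the import path this patch adds to TTS/utils/generic_utils.py; the masked_fill_softmax helper and the tensor shapes are made up for illustration.

    from typing import Optional

    import torch

    from TTS.utils.generic_utils import default, exists


    def masked_fill_softmax(logits: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Because exists() is annotated with TypeIs[_T], type checkers narrow
        # Optional[torch.Tensor] to torch.Tensor inside this branch, which the
        # per-module `val is not None` copies being removed could not express.
        if exists(mask):
            logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
        return logits.softmax(dim=-1)


    # default() takes either a plain fallback or a zero-argument callable; the
    # callable is invoked only when the first argument is None, so an expensive
    # default is not constructed eagerly.
    mask = None
    mask = default(mask, lambda: torch.ones(2, 8, dtype=torch.bool))
    probs = masked_fill_softmax(torch.randn(2, 8), mask)

Two behavioural details of the consolidation worth flagging: the shared default() tests callable(d) where the xtransformers copy used isfunction(d), so callable objects and classes now also count as lazy factories; and the dvae.py copy, which previously returned d as-is without ever invoking it, gains the call-on-demand behaviour.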