mirror of https://github.com/coqui-ai/TTS.git
refactor: move exists() and default() into generic_utils
This commit is contained in:
parent
fa844e0fb7
commit
b45a7a4220
|
@ -14,6 +14,8 @@ from torch import nn
|
|||
from torchaudio.functional import resample
|
||||
from transformers import HubertModel
|
||||
|
||||
from TTS.utils.generic_utils import exists
|
||||
|
||||
|
||||
def round_down_nearest_multiple(num, divisor):
|
||||
return num // divisor * divisor
|
||||
|
@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False):
|
|||
return t[..., seq_slice]
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
class CustomHubert(nn.Module):
|
||||
"""
|
||||
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
||||
|
|
|
@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer
|
|||
from TTS.tts.layers.tortoise.xtransformers import Encoder
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def masked_mean(t, mask, dim=1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.0)
|
||||
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
||||
|
|
|
@ -1,22 +1,19 @@
|
|||
from typing import TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.generic_utils import exists
|
||||
|
||||
# helpers
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def cast_tuple(val, depth=1):
|
||||
def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
|
||||
if isinstance(val, list):
|
||||
val = tuple(val)
|
||||
return tuple(val)
|
||||
return val if isinstance(val, tuple) else (val,) * depth
|
||||
|
||||
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import math
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value
|
||||
from TTS.utils.generic_utils import default, exists
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
|
||||
|
@ -25,20 +27,6 @@ LayerIntermediates = namedtuple(
|
|||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def cast_tuple(val, depth):
|
||||
return val if isinstance(val, tuple) else (val,) * depth
|
||||
|
||||
|
||||
class always:
|
||||
def __init__(self, val):
|
||||
self.val = val
|
||||
|
@ -63,10 +51,6 @@ class equals:
|
|||
return x == self.val
|
||||
|
||||
|
||||
def max_neg_value(tensor):
|
||||
return -torch.finfo(tensor.dtype).max
|
||||
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, p=2, dim=-1)
|
||||
|
||||
|
|
|
@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if val is not None else d
|
||||
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
|
|
|
@ -10,10 +10,7 @@ from einops.layers.torch import Rearrange
|
|||
from torch import einsum, nn
|
||||
|
||||
from TTS.tts.layers.tortoise.transformer import GEGLU
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
from TTS.utils.generic_utils import default, exists
|
||||
|
||||
|
||||
def once(fn):
|
||||
|
@ -153,12 +150,6 @@ def Sequential(*mods):
|
|||
return nn.Sequential(*filter(exists, mods))
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, scale=True, dim_cond=None):
|
||||
super().__init__()
|
||||
|
|
|
@ -4,13 +4,26 @@ import importlib
|
|||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Callable, Dict, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def exists(val: Union[_T, None]) -> TypeIs[_T]:
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if callable(d) else d
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
|
|
Loading…
Reference in New Issue