mirror of https://github.com/coqui-ai/TTS.git
refactor: move exists() and default() into generic_utils
This commit is contained in:
parent
fa844e0fb7
commit
b45a7a4220
|
@ -14,6 +14,8 @@ from torch import nn
|
||||||
from torchaudio.functional import resample
|
from torchaudio.functional import resample
|
||||||
from transformers import HubertModel
|
from transformers import HubertModel
|
||||||
|
|
||||||
|
from TTS.utils.generic_utils import exists
|
||||||
|
|
||||||
|
|
||||||
def round_down_nearest_multiple(num, divisor):
|
def round_down_nearest_multiple(num, divisor):
|
||||||
return num // divisor * divisor
|
return num // divisor * divisor
|
||||||
|
@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False):
|
||||||
return t[..., seq_slice]
|
return t[..., seq_slice]
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if exists(val) else d
|
|
||||||
|
|
||||||
|
|
||||||
class CustomHubert(nn.Module):
|
class CustomHubert(nn.Module):
|
||||||
"""
|
"""
|
||||||
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
||||||
|
|
|
@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.transformer import Transformer
|
||||||
from TTS.tts.layers.tortoise.xtransformers import Encoder
|
from TTS.tts.layers.tortoise.xtransformers import Encoder
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def masked_mean(t, mask, dim=1):
|
def masked_mean(t, mask, dim=1):
|
||||||
t = t.masked_fill(~mask[:, :, None], 0.0)
|
t = t.masked_fill(~mask[:, :, None], 0.0)
|
||||||
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
return t.sum(dim=1) / mask.sum(dim=1)[..., None]
|
||||||
|
|
|
@ -1,22 +1,19 @@
|
||||||
|
from typing import TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from TTS.utils.generic_utils import exists
|
||||||
|
|
||||||
# helpers
|
# helpers
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]:
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if exists(val) else d
|
|
||||||
|
|
||||||
|
|
||||||
def cast_tuple(val, depth=1):
|
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
val = tuple(val)
|
return tuple(val)
|
||||||
return val if isinstance(val, tuple) else (val,) * depth
|
return val if isinstance(val, tuple) else (val,) * depth
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
import math
|
import math
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from inspect import isfunction
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from torch import einsum, nn
|
from torch import einsum, nn
|
||||||
|
|
||||||
|
from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value
|
||||||
|
from TTS.utils.generic_utils import default, exists
|
||||||
|
|
||||||
DEFAULT_DIM_HEAD = 64
|
DEFAULT_DIM_HEAD = 64
|
||||||
|
|
||||||
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
|
Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"])
|
||||||
|
@ -25,20 +27,6 @@ LayerIntermediates = namedtuple(
|
||||||
# helpers
|
# helpers
|
||||||
|
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
if exists(val):
|
|
||||||
return val
|
|
||||||
return d() if isfunction(d) else d
|
|
||||||
|
|
||||||
|
|
||||||
def cast_tuple(val, depth):
|
|
||||||
return val if isinstance(val, tuple) else (val,) * depth
|
|
||||||
|
|
||||||
|
|
||||||
class always:
|
class always:
|
||||||
def __init__(self, val):
|
def __init__(self, val):
|
||||||
self.val = val
|
self.val = val
|
||||||
|
@ -63,10 +51,6 @@ class equals:
|
||||||
return x == self.val
|
return x == self.val
|
||||||
|
|
||||||
|
|
||||||
def max_neg_value(tensor):
|
|
||||||
return -torch.finfo(tensor.dtype).max
|
|
||||||
|
|
||||||
|
|
||||||
def l2norm(t):
|
def l2norm(t):
|
||||||
return F.normalize(t, p=2, dim=-1)
|
return F.normalize(t, p=2, dim=-1)
|
||||||
|
|
||||||
|
|
|
@ -14,10 +14,6 @@ from TTS.utils.generic_utils import is_pytorch_at_least_2_4
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
return val if val is not None else d
|
|
||||||
|
|
||||||
|
|
||||||
def eval_decorator(fn):
|
def eval_decorator(fn):
|
||||||
def inner(model, *args, **kwargs):
|
def inner(model, *args, **kwargs):
|
||||||
was_training = model.training
|
was_training = model.training
|
||||||
|
|
|
@ -10,10 +10,7 @@ from einops.layers.torch import Rearrange
|
||||||
from torch import einsum, nn
|
from torch import einsum, nn
|
||||||
|
|
||||||
from TTS.tts.layers.tortoise.transformer import GEGLU
|
from TTS.tts.layers.tortoise.transformer import GEGLU
|
||||||
|
from TTS.utils.generic_utils import default, exists
|
||||||
|
|
||||||
def exists(val):
|
|
||||||
return val is not None
|
|
||||||
|
|
||||||
|
|
||||||
def once(fn):
|
def once(fn):
|
||||||
|
@ -153,12 +150,6 @@ def Sequential(*mods):
|
||||||
return nn.Sequential(*filter(exists, mods))
|
return nn.Sequential(*filter(exists, mods))
|
||||||
|
|
||||||
|
|
||||||
def default(val, d):
|
|
||||||
if exists(val):
|
|
||||||
return val
|
|
||||||
return d() if callable(d) else d
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
def __init__(self, dim, scale=True, dim_cond=None):
|
def __init__(self, dim, scale=True, dim_cond=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
|
@ -4,13 +4,26 @@ import importlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional
|
from typing import Callable, Dict, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
def exists(val: Union[_T, None]) -> TypeIs[_T]:
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
|
||||||
|
def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T:
|
||||||
|
if exists(val):
|
||||||
|
return val
|
||||||
|
return d() if callable(d) else d
|
||||||
|
|
||||||
|
|
||||||
def to_camel(text):
|
def to_camel(text):
|
||||||
text = text.capitalize()
|
text = text.capitalize()
|
||||||
|
|
Loading…
Reference in New Issue