mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): reuse functions/classes from tortoise
This commit is contained in:
parent 1f27f994a1
commit 66701e1e51
@@ -6,10 +6,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
 
-class GroupNorm32(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
-
 
 def conv_nd(dims, *args, **kwargs):
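The GroupNorm32 copy removed above exists to keep group normalization numerically stable under mixed precision: the input is cast to float32 for the statistics and cast back to its original dtype afterwards. A minimal standalone sketch of that behaviour (not part of the commit, same three lines as the deleted duplicate):

import torch
from torch import nn


class GroupNorm32(nn.GroupNorm):
    # Compute the normalization in float32, return the result in the input dtype.
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


x = torch.randn(2, 32, 50, dtype=torch.float16)
y = GroupNorm32(8, 32)(x)
print(y.dtype)  # torch.float16 -- the statistics were computed in float32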
@@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")
 
 
-def normalization(channels):
-    groups = 32
-    if channels <= 16:
-        groups = 8
-    elif channels <= 64:
-        groups = 16
-    while channels % groups != 0:
-        groups = int(groups / 2)
-    assert groups > 2
-    return GroupNorm32(groups, channels)
-
-
-def zero_module(module):
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
 class QKVAttention(nn.Module):
     def __init__(self, n_heads):
         super().__init__()
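After this hunk, xtts relies on the tortoise helpers instead of the deleted local copies. A small usage sketch, assuming TTS is installed and that the tortoise versions behave like the duplicates removed above: normalization() picks the largest group count out of 32/16/8 that divides the channel count, and zero_module() zeroes a module's parameters in place, a common initialisation for the last layer of a residual branch.

import torch
from torch import nn

from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module

norm = normalization(80)                 # 80 > 64 -> try 32 groups, fall back to 16
x = torch.randn(2, 80, 100)
print(norm(x).shape)                     # torch.Size([2, 80, 100])

proj = zero_module(nn.Conv1d(80, 80, kernel_size=1))
print(proj.weight.abs().sum().item())    # 0.0 -- parameters start at zero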
@@ -9,6 +9,8 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import einsum, nn
 
+from TTS.tts.layers.tortoise.transformer import GEGLU
+
 
 def exists(val):
     return val is not None
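The newly imported GEGLU is the gated-GELU activation used by the feed-forward block later in this file; the local duplicate it replaces is deleted in the next hunk. A quick sketch of what it does, assuming the tortoise class matches that duplicate: it splits the last dimension in half and gates one half with the GELU of the other, so the output has half the input features.

import torch
from torch import nn
from torch.nn import functional as F


class GEGLU(nn.Module):
    # Mirrors the duplicate removed in the next hunk.
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * x


x = torch.randn(4, 10, 512)   # last dim is 2 * dim_inner
print(GEGLU()(x).shape)       # torch.Size([4, 10, 256])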
@@ -194,12 +196,6 @@ class CausalConv1d(nn.Conv1d):
         return super().forward(causal_padded_x)
 
 
-class GEGLU(nn.Module):
-    def forward(self, x):
-        x, gate = x.chunk(2, dim=-1)
-        return F.gelu(gate) * x
-
-
 def FeedForward(dim, mult=4, causal_conv=False):
     dim_inner = int(dim * mult * 2 / 3)
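The dim_inner = int(dim * mult * 2 / 3) line in FeedForward compensates for GEGLU halving the feature dimension: the input projection has to produce 2 * dim_inner features, so choosing dim_inner = 2/3 * mult * dim keeps the parameter count close to a plain GELU MLP with hidden size mult * dim. Below is a hedged sketch of that typical GEGLU feed-forward layout, not the exact body of this FeedForward (which also supports the causal_conv option shown in its signature).

import torch
from torch import nn
from torch.nn import functional as F


class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.gelu(gate) * x


def feed_forward_sketch(dim, mult=4):
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2),   # project to value + gate halves
        GEGLU(),                         # -> dim_inner features
        nn.Linear(dim_inner, dim),       # project back to the model width
    )


ff = feed_forward_sketch(512)
print(ff(torch.randn(1, 16, 512)).shape)   # torch.Size([1, 16, 512])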