mirror of https://github.com/coqui-ai/TTS.git
refactor(xtts): reuse functions/classes from tortoise
This commit is contained in:
parent
1f27f994a1
commit
66701e1e51
|
@ -6,10 +6,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
|
@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs):
|
|||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
groups = 32
|
||||
if channels <= 16:
|
||||
groups = 8
|
||||
elif channels <= 64:
|
||||
groups = 16
|
||||
while channels % groups != 0:
|
||||
groups = int(groups / 2)
|
||||
assert groups > 2
|
||||
return GroupNorm32(groups, channels)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
|
|
|
@ -9,6 +9,8 @@ from einops import rearrange, repeat
|
|||
from einops.layers.torch import Rearrange
|
||||
from torch import einsum, nn
|
||||
|
||||
from TTS.tts.layers.tortoise.transformer import GEGLU
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
@ -194,12 +196,6 @@ class CausalConv1d(nn.Conv1d):
|
|||
return super().forward(causal_padded_x)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.gelu(gate) * x
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, causal_conv=False):
|
||||
dim_inner = int(dim * mult * 2 / 3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue