refactor(xtts): reuse functions/classes from tortoise

This commit is contained in:
Enno Hermann 2024-11-21 12:21:38 +01:00
parent 1f27f994a1
commit 66701e1e51
2 changed files with 3 additions and 28 deletions

View File

@ -6,10 +6,7 @@ import torch
from torch import nn
from torch.nn import functional as F
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
def conv_nd(dims, *args, **kwargs):
@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs):
raise ValueError(f"unsupported dimensions: {dims}")
def normalization(channels):
groups = 32
if channels <= 16:
groups = 8
elif channels <= 64:
groups = 16
while channels % groups != 0:
groups = int(groups / 2)
assert groups > 2
return GroupNorm32(groups, channels)
def zero_module(module):
for p in module.parameters():
p.detach().zero_()
return module
class QKVAttention(nn.Module):
def __init__(self, n_heads):
super().__init__()

View File

@ -9,6 +9,8 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum, nn
from TTS.tts.layers.tortoise.transformer import GEGLU
def exists(val):
return val is not None
@ -194,12 +196,6 @@ class CausalConv1d(nn.Conv1d):
return super().forward(causal_padded_x)
class GEGLU(nn.Module):
def forward(self, x):
x, gate = x.chunk(2, dim=-1)
return F.gelu(gate) * x
def FeedForward(dim, mult=4, causal_conv=False):
dim_inner = int(dim * mult * 2 / 3)