diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py
index f9d62a36..7d385ec4 100644
--- a/TTS/tts/layers/xtts/latent_encoder.py
+++ b/TTS/tts/layers/xtts/latent_encoder.py
@@ -6,10 +6,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-
-class GroupNorm32(nn.GroupNorm):
-    def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
+from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
 
 
 def conv_nd(dims, *args, **kwargs):
@@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")
 
 
-def normalization(channels):
-    groups = 32
-    if channels <= 16:
-        groups = 8
-    elif channels <= 64:
-        groups = 16
-    while channels % groups != 0:
-        groups = int(groups / 2)
-    assert groups > 2
-    return GroupNorm32(groups, channels)
-
-
-def zero_module(module):
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
 class QKVAttention(nn.Module):
     def __init__(self, n_heads):
         super().__init__()
diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py
index f4b6e841..4b42a0e4 100644
--- a/TTS/tts/layers/xtts/perceiver_encoder.py
+++ b/TTS/tts/layers/xtts/perceiver_encoder.py
@@ -9,6 +9,8 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import einsum, nn
 
+from TTS.tts.layers.tortoise.transformer import GEGLU
+
 
 def exists(val):
     return val is not None
@@ -194,12 +196,6 @@ class CausalConv1d(nn.Conv1d):
     def forward(self, x):
         return super().forward(causal_padded_x)
 
 
-class GEGLU(nn.Module):
-    def forward(self, x):
-        x, gate = x.chunk(2, dim=-1)
-        return F.gelu(gate) * x
-
-
 def FeedForward(dim, mult=4, causal_conv=False):
     dim_inner = int(dim * mult * 2 / 3)
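
Sanity check (not part of the diff): this change assumes that `normalization`, `zero_module`, and `GEGLU` in the tortoise layers are drop-in matches for the local copies removed above. A minimal sketch to verify that assumption, with the expected values read off the removed definitions:

```python
# Sketch of a sanity check for the deduplicated helpers; assumes the tortoise
# copies behave exactly like the local definitions removed in this diff.
# Run from the repo root with the TTS package importable.
import torch
from torch import nn

from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module
from TTS.tts.layers.tortoise.transformer import GEGLU

# normalization(channels) returns a GroupNorm whose group count divides the
# channel count: it starts at 32 (8 if channels <= 16, 16 if <= 64) and halves
# until it divides. For 80 channels: 80 % 32 != 0, so it settles on 16 groups.
norm = normalization(80)
assert norm.num_groups == 16 and norm.num_channels == 80

# zero_module() zeroes every parameter in place, so the wrapped module
# contributes nothing at initialization (used for output projections).
proj = zero_module(nn.Conv1d(4, 4, 1))
assert all(p.abs().sum() == 0 for p in proj.parameters())

# GEGLU splits the last dimension in half and gates one half with the GELU of
# the other, so the output's last dimension is half the input's.
out = GEGLU()(torch.randn(2, 10, 8))
assert out.shape == (2, 10, 4)
```

Importing the tortoise versions keeps a single source of truth for helpers that the XTTS layers originally ported from tortoise-tts.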