From 9f80e043e4746982371125a293d3b7427b043536 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Mon, 24 Jun 2024 13:28:14 +0200
Subject: [PATCH] refactor(freevc): use existing layernorm

---
 TTS/vc/modules/freevc/modules.py | 24 +++++-------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py
index 9bb54990..da5bef8a 100644
--- a/TTS/vc/modules/freevc/modules.py
+++ b/TTS/vc/modules/freevc/modules.py
@@ -6,26 +6,12 @@ from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 
 import TTS.vc.modules.freevc.commons as commons
+from TTS.tts.layers.generic.normalization import LayerNorm2
 from TTS.vc.modules.freevc.commons import get_padding, init_weights
 
 LRELU_SLOPE = 0.1
 
 
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-
 class ConvReluNorm(nn.Module):
     def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
         super().__init__()
@@ -40,11 +26,11 @@ class ConvReluNorm(nn.Module):
         self.conv_layers = nn.ModuleList()
         self.norm_layers = nn.ModuleList()
         self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.norm_layers.append(LayerNorm2(hidden_channels))
         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
         for _ in range(n_layers - 1):
             self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
-            self.norm_layers.append(LayerNorm(hidden_channels))
+            self.norm_layers.append(LayerNorm2(hidden_channels))
         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
         self.proj.weight.data.zero_()
         self.proj.bias.data.zero_()
@@ -83,8 +69,8 @@ class DDSConv(nn.Module):
                 nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
             )
             self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
-            self.norms_1.append(LayerNorm(channels))
-            self.norms_2.append(LayerNorm(channels))
+            self.norms_1.append(LayerNorm2(channels))
+            self.norms_2.append(LayerNorm2(channels))
 
     def forward(self, x, x_mask, g=None):
         if g is not None:
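
Note (not part of the patch): the refactor relies on LayerNorm2 from
TTS.tts.layers.generic.normalization behaving like the LayerNorm class deleted
above, i.e. layer-normalizing the channel axis of (batch, channels, time)
tensors. The following is a minimal sanity-check sketch under that assumption;
the batch size, channel count (192) and sequence length are illustrative
values, not taken from the patch.

# Sanity-check sketch: verify LayerNorm2 matches the removed FreeVC LayerNorm.
# Assumes a TTS checkout/install where this module path is importable.
import torch
import torch.nn.functional as F

from TTS.tts.layers.generic.normalization import LayerNorm2

x = torch.randn(2, 192, 50)   # (batch, channels, time)
norm = LayerNorm2(192)        # freshly initialized, affine params should be identity

# What the deleted class computed: move channels to the last axis, apply
# F.layer_norm over that axis, then transpose back.
expected = F.layer_norm(x.transpose(1, -1), (192,)).transpose(1, -1)

assert torch.allclose(norm(x), expected, atol=1e-6)
print("LayerNorm2 matches the removed LayerNorm on this input")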