refactor(freevc): remove duplicate code

This commit is contained in:
Enno Hermann 2024-11-22 12:12:50 +01:00
parent 2e5f68df6a
commit 69a599d403
4 changed files with 8 additions and 103 deletions

View File

@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP
class DiscriminatorS(torch.nn.Module):
@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module):
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
feat.append(x)
x = self.conv_post(x)
feat.append(x)

View File

@ -6,15 +6,15 @@ import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec
import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules
from TTS.tts.layers.vits.discriminator import DiscriminatorS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.vc.configs.freevc_config import FreeVCConfig
@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_generator import get_padding
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
logger = logging.getLogger(__name__)
@ -164,75 +164,6 @@ class Generator(torch.nn.Module):
remove_parametrizations(l, "weight")
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
]
)
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__()

View File

@ -3,7 +3,7 @@ import math
import torch
from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
from TTS.tts.utils.helpers import convert_pad_shape
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
@ -96,37 +96,11 @@ def subsequent_mask(length):
return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a + input_b
t_act = torch.tanh(in_act[:, :n_channels_int, :])
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts
def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x
def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path
def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]

View File

@ -5,8 +5,8 @@ from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
from TTS.vc.modules.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding
@ -99,7 +99,7 @@ class WN(torch.nn.Module):
else:
g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts)