refactor(freevc): remove duplicate code

This commit is contained in:
Enno Hermann 2024-11-22 12:12:50 +01:00
parent 2e5f68df6a
commit 69a599d403
4 changed files with 8 additions and 103 deletions

View File

@@ -2,7 +2,7 @@ import torch
from torch import nn from torch import nn
from torch.nn.modules.conv import Conv1d from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP
class DiscriminatorS(torch.nn.Module): class DiscriminatorS(torch.nn.Module):
@@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module):
feat = [] feat = []
for l in self.convs: for l in self.convs:
x = l(x) x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1) x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
feat.append(x) feat.append(x)
x = self.conv_post(x) x = self.conv_post(x)
feat.append(x) feat.append(x)

View File

@@ -6,15 +6,15 @@ import numpy as np
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d from torch.nn import Conv1d, ConvTranspose1d
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn.utils import spectral_norm
from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec from trainer.io import load_fsspec
import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.commons as commons
import TTS.vc.modules.freevc.modules as modules import TTS.vc.modules.freevc.modules as modules
from TTS.tts.layers.vits.discriminator import DiscriminatorS
from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.vc.configs.freevc_config import FreeVCConfig from TTS.vc.configs.freevc_config import FreeVCConfig
@@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
from TTS.vc.modules.freevc.wavlm import get_wavlm from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_generator import get_padding from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -164,75 +164,6 @@ class Generator(torch.nn.Module):
remove_parametrizations(l, "weight") remove_parametrizations(l, "weight")
class DiscriminatorP(torch.nn.Module):
    """Period-based discriminator (HiFi-GAN MPD sub-module).

    Folds the 1D waveform into a 2D grid of width ``period`` and scores it
    with a stack of strided 2D convolutions, returning the flattened score
    and the per-layer feature maps (for feature-matching loss).
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        pad = (get_padding(kernel_size, 1), 0)
        # Channel progression of the strided stack; the last layer is stride 1.
        channel_pairs = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=pad))
            for c_in, c_out in channel_pairs
        ]
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # Reshape [b, c, t] -> [b, c, t // period, period], reflect-padding t
        # on the right so it becomes a multiple of the period first.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, modules.LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module): class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False):
super(MultiPeriodDiscriminator, self).__init__() super(MultiPeriodDiscriminator, self).__init__()

View File

@@ -3,7 +3,7 @@ import math
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask from TTS.tts.utils.helpers import convert_pad_shape
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
@@ -96,37 +96,11 @@ def subsequent_mask(length):
return mask return mask
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """WaveNet-style fused gated activation.

    Sums the two inputs, then feeds the first ``n_channels[0]`` channels
    through tanh and the remaining channels through sigmoid, returning
    their element-wise product.
    """
    num_channels = n_channels[0]
    pre_activation = input_a + input_b
    tanh_out = torch.tanh(pre_activation[:, :num_channels, :])
    sigmoid_out = torch.sigmoid(pre_activation[:, num_channels:, :])
    return tanh_out * sigmoid_out
def shift_1d(x): def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x return x
def generate_path(duration, mask):
    """Build a hard monotonic alignment path from per-input durations.

    Args:
        duration: [b, 1, t_x] — non-negative duration (in output frames) per input step.
        mask: [b, 1, t_y, t_x] — valid-region mask; the path is zeroed outside it.

    Returns:
        Tensor of shape [b, 1, t_y, t_x] marking, for each input step x,
        the output frames assigned to it.
    """
    b, _, t_y, t_x = mask.shape
    # Cumulative end frame of each input step along t_x.
    cum_duration = torch.cumsum(duration, -1)
    cum_duration_flat = cum_duration.view(b * t_x)
    # sequence_mask presumably expands each cumulative length into a prefix
    # of ones over t_y — TODO confirm against TTS.tts.utils.helpers.
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the previous step's prefix so only this step's own frames
    # remain set (prefix difference along the t_x axis).
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    # [b, t_x, t_y] -> [b, 1, t_y, t_x], then clip to the valid region.
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
def clip_grad_value_(parameters, clip_value, norm_type=2): def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]

View File

@@ -5,8 +5,8 @@ from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations from torch.nn.utils.parametrize import remove_parametrizations
import TTS.vc.modules.freevc.commons as commons
from TTS.tts.layers.generic.normalization import LayerNorm2 from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
from TTS.vc.modules.freevc.commons import init_weights from TTS.vc.modules.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding from TTS.vocoder.models.hifigan_generator import get_padding
@@ -99,7 +99,7 @@ class WN(torch.nn.Module):
else: else:
g_l = torch.zeros_like(x_in) g_l = torch.zeros_like(x_in)
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts) acts = self.drop(acts)
res_skip_acts = self.res_skip_layers[i](acts) res_skip_acts = self.res_skip_layers[i](acts)