diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 3449739f..49f7a0d0 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ import torch from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP +from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP class DiscriminatorS(torch.nn.Module): @@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module): feat = [] for l in self.convs: x = l(x) - x = torch.nn.functional.leaky_relu(x, 0.1) + x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) feat.append(x) x = self.conv_post(x) feat.append(x) diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index e5cfdc1e..62559de5 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -6,15 +6,15 @@ import numpy as np import torch from coqpit import Coqpit from torch import nn -from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import spectral_norm from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.modules as modules +from TTS.tts.layers.vits.discriminator import DiscriminatorS from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.vc.configs.freevc_config import FreeVCConfig @@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.wavlm import get_wavlm -from TTS.vocoder.models.hifigan_generator import get_padding +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP logger = logging.getLogger(__name__) @@ -164,75 +164,6 @@ class Generator(torch.nn.Module): remove_parametrizations(l, "weight") -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - class MultiPeriodDiscriminator(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(MultiPeriodDiscriminator, self).__init__() diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index feea7f34..49889e48 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -3,7 +3,7 @@ import math import torch from torch.nn import functional as F -from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask +from TTS.tts.utils.helpers import convert_pad_shape def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: @@ -96,37 +96,11 @@ def subsequent_mask(length): return mask -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - def shift_1d(x): x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] return x -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - def clip_grad_value_(parameters, clip_value, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py index 722444a3..ea17be24 100644 --- a/TTS/vc/modules/freevc/modules.py +++ b/TTS/vc/modules/freevc/modules.py @@ -5,8 +5,8 @@ from torch.nn import functional as F from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations -import TTS.vc.modules.freevc.commons as commons from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply from TTS.vc.modules.freevc.commons import init_weights from TTS.vocoder.models.hifigan_generator import get_padding @@ -99,7 +99,7 @@ class WN(torch.nn.Module): else: g_l = torch.zeros_like(x_in) - acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) acts = self.drop(acts) res_skip_acts = self.res_skip_layers[i](acts)