mirror of https://github.com/coqui-ai/TTS.git
refactor(freevc): remove duplicate code
This commit is contained in:
parent
2e5f68df6a
commit
69a599d403
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn.modules.conv import Conv1d
|
||||
|
||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
||||
from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
|
@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module):
|
|||
feat = []
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = torch.nn.functional.leaky_relu(x, 0.1)
|
||||
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
||||
feat.append(x)
|
||||
x = self.conv_post(x)
|
||||
feat.append(x)
|
||||
|
|
|
@ -6,15 +6,15 @@ import numpy as np
|
|||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
from trainer.io import load_fsspec
|
||||
|
||||
import TTS.vc.modules.freevc.commons as commons
|
||||
import TTS.vc.modules.freevc.modules as modules
|
||||
from TTS.tts.layers.vits.discriminator import DiscriminatorS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||
|
@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights
|
|||
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
||||
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
||||
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
||||
from TTS.vocoder.models.hifigan_generator import get_padding
|
||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -164,75 +164,6 @@ class Generator(torch.nn.Module):
|
|||
remove_parametrizations(l, "weight")
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||
super(DiscriminatorP, self).__init__()
|
||||
self.period = period
|
||||
self.use_spectral_norm = use_spectral_norm
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class DiscriminatorS(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(DiscriminatorS, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
|
||||
from TTS.tts.utils.helpers import convert_pad_shape
|
||||
|
||||
|
||||
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
|
||||
|
@ -96,37 +96,11 @@ def subsequent_mask(length):
|
|||
return mask
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||
n_channels_int = n_channels[0]
|
||||
in_act = input_a + input_b
|
||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||
acts = t_act * s_act
|
||||
return acts
|
||||
|
||||
|
||||
def shift_1d(x):
|
||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||
return x
|
||||
|
||||
|
||||
def generate_path(duration, mask):
|
||||
"""
|
||||
duration: [b, 1, t_x]
|
||||
mask: [b, 1, t_y, t_x]
|
||||
"""
|
||||
b, _, t_y, t_x = mask.shape
|
||||
cum_duration = torch.cumsum(duration, -1)
|
||||
|
||||
cum_duration_flat = cum_duration.view(b * t_x)
|
||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||
path = path.view(b, t_x, t_y)
|
||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||
path = path.unsqueeze(1).transpose(2, 3) * mask
|
||||
return path
|
||||
|
||||
|
||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||
if isinstance(parameters, torch.Tensor):
|
||||
parameters = [parameters]
|
||||
|
|
|
@ -5,8 +5,8 @@ from torch.nn import functional as F
|
|||
from torch.nn.utils.parametrizations import weight_norm
|
||||
from torch.nn.utils.parametrize import remove_parametrizations
|
||||
|
||||
import TTS.vc.modules.freevc.commons as commons
|
||||
from TTS.tts.layers.generic.normalization import LayerNorm2
|
||||
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
|
||||
from TTS.vc.modules.freevc.commons import init_weights
|
||||
from TTS.vocoder.models.hifigan_generator import get_padding
|
||||
|
||||
|
@ -99,7 +99,7 @@ class WN(torch.nn.Module):
|
|||
else:
|
||||
g_l = torch.zeros_like(x_in)
|
||||
|
||||
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||
acts = self.drop(acts)
|
||||
|
||||
res_skip_acts = self.res_skip_layers[i](acts)
|
||||
|
|
Loading…
Reference in New Issue