mirror of https://github.com/coqui-ai/TTS.git
refactor(freevc): remove duplicate code
This commit is contained in:
parent
2e5f68df6a
commit
69a599d403
|
@ -2,7 +2,7 @@ import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.modules.conv import Conv1d
|
from torch.nn.modules.conv import Conv1d
|
||||||
|
|
||||||
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorS(torch.nn.Module):
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
@ -39,7 +39,7 @@ class DiscriminatorS(torch.nn.Module):
|
||||||
feat = []
|
feat = []
|
||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
x = l(x)
|
x = l(x)
|
||||||
x = torch.nn.functional.leaky_relu(x, 0.1)
|
x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
|
||||||
feat.append(x)
|
feat.append(x)
|
||||||
x = self.conv_post(x)
|
x = self.conv_post(x)
|
||||||
feat.append(x)
|
feat.append(x)
|
||||||
|
|
|
@ -6,15 +6,15 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torch.nn.utils import spectral_norm
|
|
||||||
from torch.nn.utils.parametrizations import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
from torch.nn.utils.parametrize import remove_parametrizations
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
from trainer.io import load_fsspec
|
from trainer.io import load_fsspec
|
||||||
|
|
||||||
import TTS.vc.modules.freevc.commons as commons
|
import TTS.vc.modules.freevc.commons as commons
|
||||||
import TTS.vc.modules.freevc.modules as modules
|
import TTS.vc.modules.freevc.modules as modules
|
||||||
|
from TTS.tts.layers.vits.discriminator import DiscriminatorS
|
||||||
from TTS.tts.utils.helpers import sequence_mask
|
from TTS.tts.utils.helpers import sequence_mask
|
||||||
from TTS.tts.utils.speakers import SpeakerManager
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
from TTS.vc.configs.freevc_config import FreeVCConfig
|
from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||||
|
@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.commons import init_weights
|
||||||
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
||||||
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
||||||
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
||||||
from TTS.vocoder.models.hifigan_generator import get_padding
|
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -164,75 +164,6 @@ class Generator(torch.nn.Module):
|
||||||
remove_parametrizations(l, "weight")
|
remove_parametrizations(l, "weight")
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorP(torch.nn.Module):
|
|
||||||
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
|
||||||
super(DiscriminatorP, self).__init__()
|
|
||||||
self.period = period
|
|
||||||
self.use_spectral_norm = use_spectral_norm
|
|
||||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
|
||||||
self.convs = nn.ModuleList(
|
|
||||||
[
|
|
||||||
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
|
||||||
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
|
||||||
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
|
||||||
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
|
||||||
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
fmap = []
|
|
||||||
|
|
||||||
# 1d to 2d
|
|
||||||
b, c, t = x.shape
|
|
||||||
if t % self.period != 0: # pad first
|
|
||||||
n_pad = self.period - (t % self.period)
|
|
||||||
x = F.pad(x, (0, n_pad), "reflect")
|
|
||||||
t = t + n_pad
|
|
||||||
x = x.view(b, c, t // self.period, self.period)
|
|
||||||
|
|
||||||
for l in self.convs:
|
|
||||||
x = l(x)
|
|
||||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
|
||||||
fmap.append(x)
|
|
||||||
x = self.conv_post(x)
|
|
||||||
fmap.append(x)
|
|
||||||
x = torch.flatten(x, 1, -1)
|
|
||||||
|
|
||||||
return x, fmap
|
|
||||||
|
|
||||||
|
|
||||||
class DiscriminatorS(torch.nn.Module):
|
|
||||||
def __init__(self, use_spectral_norm=False):
|
|
||||||
super(DiscriminatorS, self).__init__()
|
|
||||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
|
||||||
self.convs = nn.ModuleList(
|
|
||||||
[
|
|
||||||
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
|
||||||
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
|
||||||
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
|
||||||
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
|
||||||
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
|
||||||
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
fmap = []
|
|
||||||
|
|
||||||
for l in self.convs:
|
|
||||||
x = l(x)
|
|
||||||
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
|
||||||
fmap.append(x)
|
|
||||||
x = self.conv_post(x)
|
|
||||||
fmap.append(x)
|
|
||||||
x = torch.flatten(x, 1, -1)
|
|
||||||
|
|
||||||
return x, fmap
|
|
||||||
|
|
||||||
|
|
||||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
def __init__(self, use_spectral_norm=False):
|
def __init__(self, use_spectral_norm=False):
|
||||||
super(MultiPeriodDiscriminator, self).__init__()
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
|
|
|
@ -3,7 +3,7 @@ import math
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask
|
from TTS.tts.utils.helpers import convert_pad_shape
|
||||||
|
|
||||||
|
|
||||||
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
|
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
|
||||||
|
@ -96,37 +96,11 @@ def subsequent_mask(length):
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
|
||||||
n_channels_int = n_channels[0]
|
|
||||||
in_act = input_a + input_b
|
|
||||||
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
|
||||||
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
|
||||||
acts = t_act * s_act
|
|
||||||
return acts
|
|
||||||
|
|
||||||
|
|
||||||
def shift_1d(x):
|
def shift_1d(x):
|
||||||
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def generate_path(duration, mask):
|
|
||||||
"""
|
|
||||||
duration: [b, 1, t_x]
|
|
||||||
mask: [b, 1, t_y, t_x]
|
|
||||||
"""
|
|
||||||
b, _, t_y, t_x = mask.shape
|
|
||||||
cum_duration = torch.cumsum(duration, -1)
|
|
||||||
|
|
||||||
cum_duration_flat = cum_duration.view(b * t_x)
|
|
||||||
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
|
||||||
path = path.view(b, t_x, t_y)
|
|
||||||
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
|
||||||
path = path.unsqueeze(1).transpose(2, 3) * mask
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||||
if isinstance(parameters, torch.Tensor):
|
if isinstance(parameters, torch.Tensor):
|
||||||
parameters = [parameters]
|
parameters = [parameters]
|
||||||
|
|
|
@ -5,8 +5,8 @@ from torch.nn import functional as F
|
||||||
from torch.nn.utils.parametrizations import weight_norm
|
from torch.nn.utils.parametrizations import weight_norm
|
||||||
from torch.nn.utils.parametrize import remove_parametrizations
|
from torch.nn.utils.parametrize import remove_parametrizations
|
||||||
|
|
||||||
import TTS.vc.modules.freevc.commons as commons
|
|
||||||
from TTS.tts.layers.generic.normalization import LayerNorm2
|
from TTS.tts.layers.generic.normalization import LayerNorm2
|
||||||
|
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
|
||||||
from TTS.vc.modules.freevc.commons import init_weights
|
from TTS.vc.modules.freevc.commons import init_weights
|
||||||
from TTS.vocoder.models.hifigan_generator import get_padding
|
from TTS.vocoder.models.hifigan_generator import get_padding
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class WN(torch.nn.Module):
|
||||||
else:
|
else:
|
||||||
g_l = torch.zeros_like(x_in)
|
g_l = torch.zeros_like(x_in)
|
||||||
|
|
||||||
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||||
acts = self.drop(acts)
|
acts = self.drop(acts)
|
||||||
|
|
||||||
res_skip_acts = self.res_skip_layers[i](acts)
|
res_skip_acts = self.res_skip_layers[i](acts)
|
||||||
|
|
Loading…
Reference in New Issue