mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'univnet' into trainer-api
This commit is contained in:
commit
7ec5c31898
|
@ -140,7 +140,11 @@ events.out*
|
|||
old_configs/*
|
||||
model_importers/*
|
||||
model_profiling/*
|
||||
<<<<<<< HEAD
|
||||
docs/source/TODO/*
|
||||
=======
|
||||
docs/*
|
||||
>>>>>>> univnet
|
||||
.noseids
|
||||
.dccache
|
||||
log.txt
|
||||
|
|
4
Makefile
4
Makefile
|
@ -6,6 +6,10 @@ help:
|
|||
|
||||
target_dirs := tests TTS notebooks
|
||||
|
||||
test_all: ## run tests and don't stop on an error.
|
||||
nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id
|
||||
./run_bash_tests.sh
|
||||
|
||||
test_all: ## run tests and don't stop on an error.
|
||||
nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id
|
||||
./run_bash_tests.sh
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig
|
||||
|
||||
|
||||
@dataclass
class UnivnetConfig(BaseGANVocoderConfig):
    """Configuration of the UnivNet GAN vocoder.

    Only the fields overridden here are documented with their defaults. Every other field
    (e.g. `seq_len`, `pad_short`, `use_noise_augment`) is inherited from
    `BaseGANVocoderConfig`; see that class for its meaning and default value.

    Example:

        >>> from TTS.vocoder.configs import UnivnetConfig
        >>> config = UnivnetConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `univnet`.
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 32.
        discriminator_model (str):
            One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `univnet_discriminator`.
        generator_model (str):
            One of the generators from `TTS.vocoder.models.*`. Defaults to `univnet_generator`.
        generator_model_params (dict):
            Keyword arguments of the generator model. `cond_channels` is overwritten with
            `audio.num_mels` in `__post_init__`. Defaults to
            `{
                "in_channels": 64,
                "out_channels": 1,
                "hidden_channels": 32,
                "cond_channels": 80,
                "upsample_factors": [8, 8, 4],
                "lvc_layers_each_block": 4,
                "lvc_kernel_size": 3,
                "kpnet_hidden_channels": 64,
                "kpnet_conv_size": 3,
                "dropout": 0.0,
            }`
        use_stft_loss (bool):
            enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
        use_subband_stft_loss (bool):
            enable / disable use of subband loss computation originally used by MultiBandMelgan model.
            Defaults to False.
        use_mse_gan_loss (bool):
            enable / disable using Mean Square Error GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN
            models. Defaults to False.
        use_feat_match_loss (bool):
            enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to False.
        use_l1_spec_loss (bool):
            enable / disable using L1 spectrogram loss. Defaults to False.
        stft_loss_weight (float):
            STFT loss weight that multiplies the computed loss before summing up the total model loss.
            Defaults to 2.5.
        stft_loss_params (dict):
            STFT loss parameters. Defaults to
            `{
                "n_ffts": [1024, 2048, 512],
                "hop_lengths": [120, 240, 50],
                "win_lengths": [600, 1200, 240]
            }`
        subband_stft_loss_weight (float):
            Subband STFT loss weight that multiplies the computed loss before summing up the total loss.
            Defaults to 0.
        mse_G_loss_weight (float):
            MSE generator loss weight that multiplies the computed loss before summing up the total loss.
            Defaults to 1.
        hinge_G_loss_weight (float):
            Hinge generator loss weight that multiplies the computed loss before summing up the total loss.
            Defaults to 0.
        feat_match_loss_weight (float):
            Feature matching loss weight that multiplies the computed loss before summing up the total loss.
            Defaults to 0.
        l1_spec_loss_weight (float):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss.
            Defaults to 0.
        l1_spec_loss_params (dict):
            L1 spectrogram loss parameters. Defaults to
            `{
                "use_mel": True,
                "sample_rate": 22050,
                "n_fft": 1024,
                "hop_length": 256,
                "win_length": 1024,
                "n_mels": 80,
                "mel_fmin": 0.0,
                "mel_fmax": None,
            }`
        lr_gen (float):
            Initial learning rate of the generator. Defaults to 1e-4.
        lr_disc (float):
            Initial learning rate of the discriminator. Defaults to 1e-4.
        lr_scheduler_gen (str):
            Name of the generator learning rate scheduler. Defaults to None.
        lr_scheduler_disc (str):
            Name of the discriminator learning rate scheduler. Defaults to None.
        optimizer_params (dict):
            Optimizer keyword arguments. Defaults to `{"betas": [0.5, 0.9], "weight_decay": 0.0}`.
        steps_to_start_discriminator (int):
            Presumably the training step at which discriminator updates begin; confirm against the
            trainer. Defaults to 200000.
    """

    model: str = "univnet"
    batch_size: int = 32
    # model specific params
    discriminator_model: str = "univnet_discriminator"
    generator_model: str = "univnet_generator"
    generator_model_params: dict = field(
        default_factory=lambda: {
            "in_channels": 64,
            "out_channels": 1,
            "hidden_channels": 32,
            "cond_channels": 80,
            "upsample_factors": [8, 8, 4],
            "lvc_layers_each_block": 4,
            "lvc_kernel_size": 3,
            "kpnet_hidden_channels": 64,
            "kpnet_conv_size": 3,
            "dropout": 0.0,
        }
    )

    # LOSS PARAMETERS - overrides
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = False
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = False
    use_feat_match_loss: bool = False  # requires MelGAN Discriminators (MelGAN and univnet)
    use_l1_spec_loss: bool = False

    # loss weights - overrides
    stft_loss_weight: float = 2.5
    stft_loss_params: dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 1
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 0
    l1_spec_loss_weight: float = 0
    l1_spec_loss_params: dict = field(
        default_factory=lambda: {
            "use_mel": True,
            "sample_rate": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": 1024,
            "n_mels": 80,
            "mel_fmin": 0.0,
            "mel_fmax": None,
        }
    )

    # optimizer parameters
    lr_gen: float = 1e-4  # Initial learning rate.
    lr_disc: float = 1e-4  # Initial learning rate.
    lr_scheduler_gen: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
    steps_to_start_discriminator: int = 200000

    def __post_init__(self):
        """Run the base-class post-init, then tie the generator conditioning size to the audio config."""
        super().__post_init__()
        # The generator is conditioned on mel frames, so its conditioning channel count
        # always follows `audio.num_mels`, overriding whatever the params dict contained.
        self.generator_model_params["cond_channels"] = self.audio.num_mels
|
|
@ -0,0 +1,198 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions.

    Maps a conditioning sequence to one convolution kernel and one bias per conditioning
    frame and per location-variable convolution layer.
    """

    def __init__(
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params=None,
    ):
        """
        Args:
            cond_channels (int): number of channels of the conditioning sequence.
            conv_in_channels (int): number of input channels of the predicted convolutions.
            conv_out_channels (int): number of output channels of the predicted convolutions.
            conv_layers (int): number of convolution layers to predict kernels and biases for.
            conv_kernel_size (int): kernel size of the predicted convolutions.
            kpnet_hidden_channels (int): number of hidden channels of the predictor network.
            kpnet_conv_size (int): kernel size of the predictor network convolutions.
            kpnet_dropout (float): dropout probability used in the residual stack.
            kpnet_nonlinear_activation (str): name of the activation class looked up in ``torch.nn``.
            kpnet_nonlinear_activation_params (dict): keyword arguments of the activation class.
                ``None`` selects ``{"negative_slope": 0.1}``.
        """
        super().__init__()

        # A dict literal as default argument would be shared by every instance.
        if kpnet_nonlinear_activation_params is None:
            kpnet_nonlinear_activation_params = {"negative_slope": 0.1}

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        # Number of output channels needed to hold all kernel weights / biases of all layers.
        l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        l_b = conv_out_channels * conv_layers

        padding = (kpnet_conv_size - 1) // 2

        def _activation():
            return getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params)

        def _hidden_conv():
            return torch.nn.Conv1d(
                kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True
            )

        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
            _activation(),
        )

        # Three identical groups of (dropout, conv, act, conv, act). The layer order, and
        # therefore the Sequential indices used as state_dict keys, must not change.
        residual_layers = []
        for _ in range(3):
            residual_layers += [
                torch.nn.Dropout(kpnet_dropout),
                _hidden_conv(),
                _activation(),
                _hidden_conv(),
                _activation(),
            ]
        self.residual_conv = torch.nn.Sequential(*residual_layers)

        self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True)
        self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True)

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: predicted kernels
                (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
            Tensor: predicted biases (batch, conv_layers, conv_out_channels, cond_length)
        """
        batch, _, cond_length = c.shape

        c = self.input_conv(c)
        c = c + self.residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)

        # Split the flat channel axis back into per-layer kernels and biases.
        kernels = k.contiguous().view(
            batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length
        )
        bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
        return kernels, bias
|
||||
|
||||
|
||||
class LVCBlock(torch.nn.Module):
    """Upsampling block built from location-variable convolutions (LVC).

    The input is upsampled with a transposed convolution, then refined by a stack of
    dilated convolutions, each followed by an LVC whose kernels are predicted from the
    conditioning sequence and combined with the input through a gated residual update.
    """

    def __init__(
        self,
        in_channels,
        cond_channels,
        upsample_ratio,
        conv_layers=4,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        """
        Args:
            in_channels (int): number of channels of the input (and output) sequence.
            cond_channels (int): number of channels of the conditioning sequence.
            upsample_ratio (int): stride of the transposed convolution.
            conv_layers (int): number of (dilated conv + LVC) layers.
            conv_kernel_size (int): kernel size of the dilated convs and of the predicted kernels.
            cond_hop_length (int): number of samples of the upsampled sequence covered by one
                conditioning frame.
            kpnet_hidden_channels (int): hidden channels of the kernel predictor.
            kpnet_conv_size (int): kernel size of the kernel predictor convolutions.
            kpnet_dropout (float): dropout probability of the kernel predictor.
        """
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        # Registered while still empty so that module / parameter ordering stays as it was.
        self.convs = torch.nn.ModuleList()

        odd_ratio = upsample_ratio % 2
        self.upsample = torch.nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size=upsample_ratio * 2,
            stride=upsample_ratio,
            padding=upsample_ratio // 2 + odd_ratio,
            output_padding=odd_ratio,
        )

        # Each predicted convolution emits 2 * in_channels: one half per gate branch.
        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
        )

        half_kernel = int((conv_kernel_size - 1) / 2)
        for layer in range(conv_layers):
            dilation = 3 ** layer
            self.convs.append(
                torch.nn.Conv1d(
                    in_channels,
                    in_channels,
                    kernel_size=conv_kernel_size,
                    padding=dilation * half_kernel,
                    dilation=dilation,
                )
            )

    def forward(self, x, c):
        """Forward propagation of the location-variable convolutions.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the upsampled output sequence (batch, in_channels, out_length). The
                upsampled length must equal ``cond_length * cond_hop_length``, otherwise the
                location-variable convolution raises an ``AssertionError``.
        """
        in_channels = x.shape[1]
        kernels, bias = self.kernel_predictor(c)

        x = self.upsample(F.leaky_relu(x, 0.2))

        for layer_idx in range(self.conv_layers):
            hidden = F.leaky_relu(x, 0.2)
            hidden = F.leaky_relu(self.convs[layer_idx](hidden), 0.2)

            layer_kernel = kernels[:, layer_idx, :, :, :, :]
            layer_bias = bias[:, layer_idx, :, :]
            hidden = self.location_variable_convolution(hidden, layer_kernel, layer_bias, 1, self.cond_hop_length)

            # Gated residual update: first half of the channels gates the second half.
            gate = torch.sigmoid(hidden[:, :in_channels, :])
            candidate = torch.tanh(hidden[:, in_channels:, :])
            x = x + gate * candidate
        return x

    @staticmethod
    def location_variable_convolution(x, kernel, bias, dilation, hop_size):
        """Convolve each ``hop_size``-long segment of ``x`` with its own kernel and bias.

        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernels
                (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the local convolution biases (batch, out_channels, kernel_length)
            dilation (int): the dilation of the convolution.
            hop_size (int): number of input samples sharing one kernel.

        Returns:
            Tensor: the output sequence (batch, out_channels, in_length).

        Raises:
            AssertionError: if ``in_length != kernel_length * hop_size``.
        """
        _, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"

        pad = dilation * int((kernel_size - 1) / 2)
        frames = F.pad(x, (pad, pad), "constant", 0)  # (batch, in_channels, in_length + 2*pad)
        # One overlapping window per kernel: (batch, in_channels, kernel_length, hop_size + 2*pad)
        frames = frames.unfold(2, hop_size + 2 * pad, hop_size)

        if hop_size < dilation:
            frames = F.pad(frames, (0, dilation), "constant", 0)
        # Split each window into dilation-strided groups, then slide the kernel over them.
        frames = frames.unfold(3, dilation, dilation)
        frames = frames[:, :, :, :, :hop_size]
        frames = frames.transpose(3, 4)
        frames = frames.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        out = torch.einsum("bildsk,biokl->bolsd", frames, kernel)
        out = out + bias[..., None, None]
        return out.contiguous().view(batch, out_channels, -1)
|
|
@ -31,6 +31,7 @@ def setup_model(config: Coqpit):
|
|||
|
||||
|
||||
def setup_generator(c):
|
||||
""" TODO: use config object as arguments"""
|
||||
print(" > Generator Model: {}".format(c.generator_model))
|
||||
MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
|
||||
MyModel = getattr(MyModel, to_camel(c.generator_model))
|
||||
|
@ -85,12 +86,15 @@ def setup_generator(c):
|
|||
use_weight_norm=True,
|
||||
upsample_factors=c.generator_model_params["upsample_factors"],
|
||||
)
|
||||
elif c.generator_model.lower() in "univnet_generator":
|
||||
model = MyModel(**c.generator_model_params)
|
||||
else:
|
||||
raise NotImplementedError(f"Model {c.generator_model} not implemented!")
|
||||
return model
|
||||
|
||||
|
||||
def setup_discriminator(c):
|
||||
""" TODO: use config objekt as arguments"""
|
||||
print(" > Discriminator Model: {}".format(c.discriminator_model))
|
||||
if "parallel_wavegan" in c.discriminator_model:
|
||||
MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
|
||||
|
@ -144,4 +148,6 @@ def setup_discriminator(c):
|
|||
nonlinear_activation_params={"negative_slope": 0.2},
|
||||
bias=True,
|
||||
)
|
||||
if c.discriminator_model == "univnet_discriminator":
|
||||
model = MyModel()
|
||||
return model
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm, weight_norm
|
||||
|
||||
from TTS.utils.audio import TorchSTFT
|
||||
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
    """Discriminator that scores a waveform from a single-resolution STFT of it.

    The waveform is converted with ``TorchSTFT`` and passed through a stack of normalised
    2D convolutions. The activation of every layer is collected as a feature map.
    """

    def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False):
        """
        Args:
            fft_size (int): FFT size passed to ``TorchSTFT``.
            hop_length (int): hop length passed to ``TorchSTFT``.
            win_length (int): window length passed to ``TorchSTFT``.
            use_spectral_norm (bool): ``False`` wraps the convolutions in weight norm; any
                other value selects spectral norm.
        """
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        self.fft_size = fft_size
        self.hop_length = hop_length
        self.win_length = win_length
        self.stft = TorchSTFT(fft_size, hop_length, win_length)

        # (in_channels, kernel_size, stride, padding) of each hidden convolution.
        layer_specs = [
            (1, (3, 9), (1, 1), (1, 4)),
            (32, (3, 9), (1, 2), (1, 4)),
            (32, (3, 9), (1, 2), (1, 4)),
            (32, (3, 9), (1, 2), (1, 4)),
            (32, (3, 3), (1, 1), (1, 1)),
        ]
        self.discriminators = nn.ModuleList(
            [
                norm_f(nn.Conv2d(channels_in, 32, kernel_size=kernel, stride=stride, padding=pad))
                for channels_in, kernel, stride, pad in layer_specs
            ]
        )

        self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))

    def forward(self, y):
        """
        Args:
            y (Tensor): input waveform; dimension 1 is squeezed away before the STFT.

        Returns:
            Tensor: output of the last convolution flattened from dimension 1 onwards.
            List[Tensor]: activations of every hidden layer followed by the unflattened output.
        """
        # NOTE(review): the STFT runs under no_grad, so this module does not propagate
        # gradients back to the input waveform. Confirm this is intended for generator training.
        with torch.no_grad():
            spec = self.stft(y.squeeze(1))
        hidden = spec.unsqueeze(1)

        fmap = []
        for layer in self.discriminators:
            hidden = F.leaky_relu(layer(hidden), LRELU_SLOPE)
            fmap.append(hidden)

        hidden = self.out(hidden)
        fmap.append(hidden)

        return torch.flatten(hidden, 1, -1), fmap
|
||||
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
    """Bank of ``SpecDiscriminator`` modules, one per STFT resolution."""

    def __init__(
        self, fft_sizes=(1024, 2048, 512), hop_sizes=(120, 240, 50), win_lengths=(600, 1200, 240), window="hann_window"
    ):
        """
        Args:
            fft_sizes (Sequence[int]): FFT size of each resolution.
            hop_sizes (Sequence[int]): hop length of each resolution.
            win_lengths (Sequence[int]): window length of each resolution.
            window (str): kept for interface compatibility only. ``SpecDiscriminator`` takes
                no window argument, so the value is not used.

        Raises:
            ValueError: if the three sequences do not have the same length.
        """
        super().__init__()
        if not len(fft_sizes) == len(hop_sizes) == len(win_lengths):
            raise ValueError(
                "fft_sizes, hop_sizes and win_lengths must have the same length, got "
                f"{len(fft_sizes)}, {len(hop_sizes)} and {len(win_lengths)}"
            )
        # `window` used to be passed as the 4th positional argument, which is
        # `use_spectral_norm` in SpecDiscriminator. The string is not `False`, so spectral
        # norm was enabled by accident. It is no longer forwarded.
        # NOTE(review): checkpoints saved with the old behaviour hold spectral-norm
        # parameter names for these layers and will not load into this module as is.
        self.discriminators = nn.ModuleList(
            [
                SpecDiscriminator(fft_size, hop_size, win_length)
                for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths)
            ]
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform, passed unchanged to every sub-discriminator.

        Returns:
            List[Tensor]: score of each sub-discriminator.
            List[List[Tensor]]: feature maps of each sub-discriminator.
        """
        scores = []
        feats = []
        for d in self.discriminators:
            score, feat = d(x)
            scores.append(score)
            feats.append(feat)

        return scores, feats
|
||||
|
||||
|
||||
class UnivnetDiscriminator(nn.Module):
    """UnivNet discriminator: a multi-period discriminator plus a multi-resolution
    spectrogram discriminator, both applied to the same waveform."""

    def __init__(self):
        super().__init__()
        self.mpd = MultiPeriodDiscriminator()
        self.msd = MultiResSpecDiscriminator()

    def forward(self, x):
        """
        Args:
            x (Tensor): input waveform.

        Returns:
            List[Tensor]: discriminator scores, period discriminators first.
            List[List[Tensor]]: list of list of features from each layer of each discriminator,
                in the same order as the scores.
        """
        period_scores, period_feats = self.mpd(x)
        spec_scores, spec_feats = self.msd(x)
        all_scores = period_scores + spec_scores
        all_feats = period_feats + spec_feats
        return all_scores, all_feats
|
|
@ -0,0 +1,145 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||
|
||||
LRELU_SLOPE = 0.1


class UnivnetGenerator(torch.nn.Module):
    """UnivNet generator: turns random noise into a waveform, conditioned on a feature
    sequence through a stack of location-variable convolution blocks."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        cond_channels,
        upsample_factors,
        lvc_layers_each_block,
        lvc_kernel_size,
        kpnet_hidden_channels,
        kpnet_conv_size,
        dropout,
        use_weight_norm=True,
    ):
        """
        Args:
            in_channels (int): number of channels of the input noise.
            out_channels (int): number of channels of the generated signal.
            hidden_channels (int): number of channels between the first and the last convolution.
            cond_channels (int): number of channels of the conditioning features.
            upsample_factors (List[int]): upsampling ratio of each LVC block; one block is
                created per entry.
            lvc_layers_each_block (int): number of LVC layers in each block.
            lvc_kernel_size (int): kernel size of the LVC layers.
            kpnet_hidden_channels (int): hidden channels of the kernel predictors.
            kpnet_conv_size (int): kernel size of the kernel predictor convolutions.
            dropout (float): dropout probability of the kernel predictors.
            use_weight_norm (bool): apply weight normalization to every Conv1d / Conv2d layer.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cond_channels = cond_channels
        self.upsample_scale = np.prod(upsample_factors)
        self.lvc_block_nums = len(upsample_factors)

        # define first convolution
        self.first_conv = torch.nn.Conv1d(
            in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
        )

        # define residual blocks
        self.lvc_blocks = torch.nn.ModuleList()
        # Running product of the upsampling ratios: how many samples one conditioning
        # frame spans after the current block.
        cond_hop_length = 1
        for upsample_ratio in upsample_factors:
            cond_hop_length = cond_hop_length * upsample_ratio
            self.lvc_blocks.append(
                LVCBlock(
                    in_channels=hidden_channels,
                    cond_channels=cond_channels,
                    upsample_ratio=upsample_ratio,
                    conv_layers=lvc_layers_each_block,
                    conv_kernel_size=lvc_kernel_size,
                    cond_hop_length=cond_hop_length,
                    kpnet_hidden_channels=kpnet_hidden_channels,
                    kpnet_conv_size=kpnet_conv_size,
                    kpnet_dropout=dropout,
                )
            )

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True
                ),
            ]
        )

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Local conditioning auxiliary features (B, C, T').

        Returns:
            Tensor: Output tensor (B, out_channels, T)
        """
        # One noise vector per conditioning frame, sampled on every call.
        x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
        x = x.to(self.first_conv.bias.device)
        x = self.first_conv(x)

        for lvc_block in self.lvc_blocks:
            x = lvc_block(x, c)

        # apply final layers
        for f in self.last_conv_layers:
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = f(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the convolution layers."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x):
        """Return the receptive field of ``layers`` dilated convolutions split into ``stacks`` cycles."""
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        """Return receptive field size."""
        # NOTE(review): `layers`, `stacks` and `kernel_size` are never set on this class, so
        # reading this property raises AttributeError. Left unchanged because the intended
        # values cannot be derived from this file.
        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

    def inference(self, c=None, x=None):
        """Perform inference.

        Args:
            c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T', C).
            x (Union[Tensor, ndarray]): Unused. Kept for interface compatibility; ``forward``
                samples its own input noise.

        Returns:
            Tensor: Output tensor (T, out_channels)

        Raises:
            ValueError: if ``c`` is not given.
        """
        # The earlier implementation read `self.upsample_factor` and `self.aux_context_window`,
        # which are never defined, so every call failed.
        del x
        if c is None:
            raise ValueError("Conditioning features `c` are required for inference.")
        if not isinstance(c, torch.Tensor):
            c = torch.tensor(c, dtype=torch.float)
        c = c.to(next(self.parameters()).device)
        c = c.transpose(1, 0).unsqueeze(0)  # (T', C) -> (1, C, T')
        with torch.no_grad():
            out = self.forward(c)
        return out.squeeze(0).transpose(1, 0)
|
|
@ -0,0 +1,29 @@
|
|||
import os
|
||||
|
||||
from TTS.trainer import Trainer, TrainingArgs, init_training
|
||||
from TTS.vocoder.configs import UnivnetConfig
|
||||
|
||||
# Training outputs are written next to this recipe file; the dataset path is resolved
# relative to the same directory.
recipe_dir = os.path.dirname(os.path.abspath(__file__))
output_path = recipe_dir

config = UnivnetConfig(
    # data
    data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
    output_path=output_path,
    eval_split_size=10,
    seq_len=8192,
    pad_short=2000,
    use_noise_augment=True,
    # batching and data loading
    batch_size=64,
    eval_batch_size=16,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    # schedule and logging
    epochs=1000,
    run_eval=True,
    test_delay_epochs=-1,
    print_step=25,
    print_eval=False,
    # optimization
    mixed_precision=False,
    lr_gen=1e-4,
    lr_disc=1e-4,
)

# init_training returns the (possibly updated) config and output path together with the loggers.
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)
trainer = Trainer(args, config, output_path, c_logger, tb_logger)
trainer.fit()
|
Loading…
Reference in New Issue