mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'pwgan' into dev
commit 3bc38517aa
@@ -0,0 +1,143 @@
{
    "run_name": "pwgan",
    "run_description": "parallel-wavegan training",

    // AUDIO PARAMETERS
    "audio":{
        "fft_size": 1024,        // number of stft frequency levels. Size of the linear spectrogram frame.
        "win_length": 1024,      // stft window length in samples.
        "hop_length": 256,       // stft window hop length in samples.
        "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
        "frame_shift_ms": null,  // stft window hop length in ms. If null, 'hop_length' is used.

        // Audio processing parameters
        "sample_rate": 22050,    // DATASET-RELATED: wav sample rate. If different than the original data, it is resampled.
        "preemphasis": 0.0,      // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "ref_level_db": 0,       // reference level db, theoretically 20db is the sound of air.

        // Silence trimming
        "do_trim_silence": true, // enable trimming of silence from audio as you load it. LJSpeech (false), TWEB (false), Nancy (true)
        "trim_db": 60,           // threshold for trimming silence. Set this according to your dataset.

        // MelSpectrogram parameters
        "num_mels": 80,          // size of the mel spec frame.
        "mel_fmin": 50.0,        // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for your dataset!
        "mel_fmax": 7600.0,      // maximum freq level for mel-spec. Tune for your dataset!
        "spec_gain": 1.0,        // scalar value applied after the log transform of the spectrogram.

        // Normalization parameters
        "signal_norm": true,     // normalize spec values. Mean-var normalization if 'stats_path' is defined, otherwise range normalization defined by the other params.
        "min_level_db": -100,    // lower bound for normalization
        "symmetric_norm": true,  // move normalization to range [-1, 1]
        "max_norm": 4.0,         // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,       // clip normalized values into the range.
        "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. Scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored.
    },

    // DISTRIBUTED TRAINING
    // "distributed":{
    //     "backend": "nccl",
    //     "url": "tcp:\/\/localhost:54321"
    // },

    // MODEL PARAMETERS
    "use_pqmf": true,

    // LOSS PARAMETERS
    "use_stft_loss": true,
    "use_subband_stft_loss": false,  // USE ONLY WITH MULTIBAND MODELS
    "use_mse_gan_loss": true,
    "use_hinge_gan_loss": false,
    "use_feat_match_loss": false,    // use only with melgan discriminators

    // loss weights
    "stft_loss_weight": 0.5,
    "subband_stft_loss_weight": 0.5,
    "mse_G_loss_weight": 2.5,
    "hinge_G_loss_weight": 2.5,
    "feat_match_loss_weight": 25,

    // multiscale stft loss parameters
    "stft_loss_params": {
        "n_ffts": [1024, 2048, 512],
        "hop_lengths": [120, 240, 50],
        "win_lengths": [600, 1200, 240]
    },

    // subband multiscale stft loss parameters
    "subband_stft_loss_params":{
        "n_ffts": [384, 683, 171],
        "hop_lengths": [30, 60, 10],
        "win_lengths": [150, 300, 60]
    },

    "target_loss": "avg_G_loss",  // loss value used to pick the best model to save after each epoch

    // DISCRIMINATOR
    "discriminator_model": "parallel_wavegan_discriminator",
    "discriminator_model_params":{
        "num_layers": 10
    },
    "steps_to_start_discriminator": 200000,  // steps required to start GAN training.

    // GENERATOR
    "generator_model": "parallel_wavegan_generator",
    "generator_model_params": {
        "upsample_factors":[4, 4, 4, 4],
        "stacks": 3,
        "num_res_blocks": 30
    },

    // DATASET
    "data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/",
    "feature_path": null,
    "seq_len": 25600,
    "pad_short": 2000,
    "conv_pad": 0,
    "use_noise_augment": false,
    "use_cache": true,

    "reinit_layers": [],  // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

    // TRAINING
    "batch_size": 6,  // batch size for training.

    // VALIDATION
    "run_eval": true,
    "test_delay_epochs": 10,       // number of epochs to wait before starting test runs.
    "test_sentences_file": null,   // set a file to load sentences to be used for testing. If it is null, default English sentences are used.

    // OPTIMIZER
    "epochs": 10000,       // total number of epochs to train.
    "wd": 0.0,             // weight decay.
    "gen_clip_grad": -1,   // generator gradient clipping threshold. Apply gradient clipping if > 0.
    "disc_clip_grad": -1,  // discriminator gradient clipping threshold.
    "lr_scheduler_gen": "MultiStepLR",  // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    "lr_scheduler_gen_params": {
        "gamma": 0.5,
        "milestones": [100000, 200000, 300000, 400000, 500000, 600000]
    },
    "lr_scheduler_disc": "MultiStepLR",  // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    "lr_scheduler_disc_params": {
        "gamma": 0.5,
        "milestones": [100000, 200000, 300000, 400000, 500000, 600000]
    },
    "lr_gen": 1e-4,  // initial learning rate. If Noam decay is active, maximum learning rate.
    "lr_disc": 1e-4,

    // TENSORBOARD and LOGGING
    "print_step": 25,     // number of steps between console logs.
    "print_eval": false,  // if true, prints loss values for each step in the eval run.
    "save_step": 25000,   // number of training steps between plotting training stats on TB and saving model checkpoints.
    "checkpoint": true,   // if true, saves checkpoints every "save_step"
    "tb_model_param_stats": false,  // if true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

    // DATA LOADING
    "num_loader_workers": 4,      // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 4,  // number of evaluation data loader processes.
    "eval_split_size": 10,

    // PATHS
    "output_path": "/home/erogol/Models/LJSpeech/"
}
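The "stft_loss_params" above drive a multi-resolution STFT loss: generated and target waveforms are compared at three STFT resolutions. A minimal sketch of that computation follows (illustrative only; the repo's own loss classes may differ in detail):

import torch

def stft_magnitude(x, n_fft, hop_length, win_length):
    # x: (B, T) waveform -> magnitude spectrogram (B, n_fft // 2 + 1, frames)
    window = torch.hann_window(win_length, device=x.device)
    spec = torch.stft(x, n_fft, hop_length, win_length, window=window,
                      return_complex=True)
    return torch.clamp(spec.abs(), min=1e-7)

def multi_resolution_stft_loss(y_hat, y, params):
    loss = 0.0
    for n_fft, hop, win in zip(params["n_ffts"],
                               params["hop_lengths"],
                               params["win_lengths"]):
        mag_hat = stft_magnitude(y_hat, n_fft, hop, win)
        mag = stft_magnitude(y, n_fft, hop, win)
        # spectral convergence + log-magnitude terms, as in the PWGAN paper
        sc = torch.norm(mag - mag_hat, p="fro") / torch.norm(mag, p="fro")
        log_mag = torch.nn.functional.l1_loss(torch.log(mag_hat), torch.log(mag))
        loss = loss + sc + log_mag
    return loss / len(params["n_ffts"])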
@@ -0,0 +1,75 @@
import torch
from torch.nn import functional as F


class ResidualBlock(torch.nn.Module):
    """Residual block module in WaveNet."""

    def __init__(self,
                 kernel_size=3,
                 res_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 dilation=1,
                 bias=True,
                 use_causal_conv=False
                 ):
        super(ResidualBlock, self).__init__()
        self.dropout = dropout
        # no future time stamps available
        if use_causal_conv:
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "kernel_size must be an odd number."
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        # dilated conv
        self.conv = torch.nn.Conv1d(res_channels, gate_channels, kernel_size,
                                    padding=padding, dilation=dilation, bias=bias)

        # local conditioning
        if aux_channels > 0:
            self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
        else:
            self.conv1x1_aux = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)

    def forward(self, x, c):
        """
        x: B x D_res x T
        c: B x D_aux x T
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv(x)

        # remove future time steps if using causal conv
        x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x

        # split in two for the gated activation
        splitdim = 1
        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1_aux is not None
            c = self.conv1x1_aux(c)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            xa, xb = xa + ca, xb + cb

        x = torch.tanh(xa) * torch.sigmoid(xb)

        # for skip connection
        s = self.conv1x1_skip(x)

        # for residual connection (note: the upstream ParallelWaveGAN
        # implementation scales by math.sqrt(0.5) here)
        x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)

        return x, s
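A quick shape check of the block (illustrative usage, not part of the commit): the residual path keeps its width while a separate skip tensor is emitted.

import torch
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock

block = ResidualBlock(kernel_size=3, res_channels=64, gate_channels=128,
                      skip_channels=64, aux_channels=80, dilation=2)
x = torch.rand(4, 64, 100)  # (B, res_channels, T)
c = torch.rand(4, 80, 100)  # (B, aux_channels, T) local conditioning
out, skip = block(x, c)
assert out.shape == (4, 64, 100)   # residual output, same shape as input
assert skip.shape == (4, 64, 100)  # skip output in skip_channels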
@@ -0,0 +1,100 @@
import torch
from torch.nn import functional as F


class Stretch2d(torch.nn.Module):
    def __init__(self, x_scale, y_scale, mode="nearest"):
        super(Stretch2d, self).__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x):
        """
        x (Tensor): Input tensor (B, C, F, T).
        Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
        """
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode)


class UpsampleNetwork(torch.nn.Module):
    def __init__(self,
                 upsample_factors,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        super(UpsampleNetwork, self).__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # interpolation layer
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # conv layer
            assert (freq_axis_kernel_size - 1) % 2 == 0, "freq_axis_kernel_size must be an odd number."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # nonlinear activation
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """
        c : (B, C, T_in).
        Tensor: (B, C, T_upsampled)
        """
        c = c.unsqueeze(1)  # (B, 1, C, T)
        for f in self.up_layers:
            c = f(c)
        return c.squeeze(1)  # (B, C, T')


class ConvUpsample(torch.nn.Module):
    def __init__(self,
                 upsample_factors,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False
                 ):
        super(ConvUpsample, self).__init__()
        self.aux_context_window = aux_context_window
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        # to capture wide-context information in conditional features
        kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)
        self.upsample = UpsampleNetwork(
            upsample_factors=upsample_factors,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )

    def forward(self, c):
        """
        c : (B, C, T_in).
        Tensor: (B, C, T_upsampled).
        """
        c_ = self.conv_in(c)
        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)
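Illustrative usage (not part of the commit): with upsample_factors=[4, 4, 4, 4] the mel sequence is stretched by 4 * 4 * 4 * 4 = 256 frames-to-samples, matching the config's hop_length.

import torch
from TTS.vocoder.layers.upsample import ConvUpsample

upsampler = ConvUpsample(upsample_factors=[4, 4, 4, 4], aux_channels=80)
c = torch.rand(4, 80, 64)               # (B, num_mels, T_frames)
c_up = upsampler(c)
assert c_up.shape == (4, 80, 64 * 256)  # one conditioning vector per output sample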
@@ -0,0 +1,192 @@
import math
import torch
from torch import nn

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock


class ParallelWaveganDiscriminator(nn.Module):
    """PWGAN discriminator as in https://arxiv.org/abs/1910.11480.
    It classifies each audio window as real or fake and returns a sequence
    of predictions.
    It is a stack of convolutional blocks with dilation.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 num_layers=10,
                 conv_channels=64,
                 dilation_factor=1,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 bias=True,
                 ):
        super(ParallelWaveganDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size must be an odd number."
        assert dilation_factor > 0, " [!] dilation factor must be > 0."
        self.conv_layers = nn.ModuleList()
        conv_in_channels = in_channels
        for i in range(num_layers - 1):
            if i == 0:
                dilation = 1
            else:
                dilation = i if dilation_factor == 1 else dilation_factor ** i
                conv_in_channels = conv_channels
            padding = (kernel_size - 1) // 2 * dilation
            conv_layer = [
                nn.Conv1d(conv_in_channels, conv_channels,
                          kernel_size=kernel_size, padding=padding,
                          dilation=dilation, bias=bias),
                getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
            ]
            self.conv_layers += conv_layer
        padding = (kernel_size - 1) // 2
        last_conv_layer = nn.Conv1d(
            conv_in_channels, out_channels,
            kernel_size=kernel_size, padding=padding, bias=bias)
        self.conv_layers += [last_conv_layer]
        self.apply_weight_norm()

    def forward(self, x):
        """
        x : (B, 1, T).
        Returns:
            Tensor: (B, 1, T)
        """
        for f in self.conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(_remove_weight_norm)


class ResidualParallelWaveganDiscriminator(nn.Module):
    """WaveNet-style discriminator built from unconditioned residual blocks
    with accumulated skip connections."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 num_layers=30,
                 stacks=3,
                 res_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 dropout=0.0,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 ):
        super(ResidualParallelWaveganDiscriminator, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size must be an odd number."

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.res_factor = math.sqrt(1.0 / num_layers)

        # check the number of num_layers and stacks
        assert num_layers % stacks == 0
        layers_per_stack = num_layers // stacks

        # define first convolution
        self.first_conv = nn.Sequential(
            nn.Conv1d(in_channels,
                      res_channels,
                      kernel_size=1,
                      padding=0,
                      dilation=1,
                      bias=True),
            getattr(nn, nonlinear_activation)(inplace=True,
                                              **nonlinear_activation_params),
        )

        # define residual blocks
        self.conv_layers = nn.ModuleList()
        for layer in range(num_layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=-1,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=False,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = nn.ModuleList([
            getattr(nn, nonlinear_activation)(inplace=True,
                                              **nonlinear_activation_params),
            nn.Conv1d(skip_channels,
                      skip_channels,
                      kernel_size=1,
                      padding=0,
                      dilation=1,
                      bias=True),
            getattr(nn, nonlinear_activation)(inplace=True,
                                              **nonlinear_activation_params),
            nn.Conv1d(skip_channels,
                      out_channels,
                      kernel_size=1,
                      padding=0,
                      dilation=1,
                      bias=True),
        ])

        # apply weight norm
        self.apply_weight_norm()

    def forward(self, x):
        """
        x: (B, 1, T).
        """
        x = self.first_conv(x)

        skips = 0
        for f in self.conv_layers:
            x, h = f(x, None)
            skips += h
        skips *= self.res_factor

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)
        return x

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
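With "use_mse_gan_loss": true in the config above, the per-sample score sequences from these discriminators feed a least-squares GAN objective. A minimal sketch, assuming plain MSE against targets of 1 for real and 0 for fake (the repo's actual loss classes may differ in detail):

import torch
import torch.nn.functional as F

def mse_gan_losses(scores_real, scores_fake):
    # discriminator: push real scores toward 1 and fake scores toward 0
    loss_D = F.mse_loss(scores_real, torch.ones_like(scores_real)) + \
             F.mse_loss(scores_fake, torch.zeros_like(scores_fake))
    # generator: make fake scores look real
    loss_G = F.mse_loss(scores_fake, torch.ones_like(scores_fake))
    return loss_D, loss_G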
@@ -0,0 +1,162 @@
import math
import numpy as np
import torch

from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample


class ParallelWaveganGenerator(torch.nn.Module):
    """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
    It is similar to WaveNet with no causal convolution.
    It is conditioned on an aux feature (spectrogram) to generate
    an output waveform from an input noise.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_size=3,
                 num_res_blocks=30,
                 stacks=3,
                 res_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 aux_context_window=2,
                 dropout=0.0,
                 bias=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 upsample_conditional_features=True,
                 upsample_net="ConvInUpsampleNetwork",
                 upsample_factors=[4, 4, 4, 4],
                 inference_padding=2):

        super(ParallelWaveganGenerator, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.num_res_blocks = num_res_blocks
        self.stacks = stacks
        self.kernel_size = kernel_size
        self.upsample_factors = upsample_factors
        self.upsample_scale = np.prod(upsample_factors)
        self.inference_padding = inference_padding

        # check the number of layers and stacks
        assert num_res_blocks % stacks == 0
        layers_per_stack = num_res_blocks // stacks

        # define first convolution
        self.first_conv = torch.nn.Conv1d(in_channels,
                                          res_channels,
                                          kernel_size=1,
                                          bias=True)

        # define conv + upsampling network
        self.upsample_net = ConvUpsample(upsample_factors=upsample_factors)

        # define residual blocks
        self.conv_layers = torch.nn.ModuleList()
        for layer in range(num_res_blocks):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                res_channels=res_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
            )
            self.conv_layers += [conv]

        # define output layers
        self.last_conv_layers = torch.nn.ModuleList([
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv1d(skip_channels,
                            skip_channels,
                            kernel_size=1,
                            bias=True),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv1d(skip_channels,
                            out_channels,
                            kernel_size=1,
                            bias=True),
        ])

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, c):
        """
        c: (B, C, T').
        o: Output tensor (B, out_channels, T)
        """
        # random noise
        x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale])
        x = x.to(self.first_conv.bias.device)

        # perform upsampling
        if c is not None and self.upsample_net is not None:
            c = self.upsample_net(c)
            assert c.shape[-1] == x.shape[-1], \
                f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}"

        # encode to hidden representation
        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, c)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        # apply final layers
        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        return x

    def inference(self, c):
        c = c.to(self.first_conv.weight.device)
        c = torch.nn.functional.pad(
            c, (self.inference_padding, self.inference_padding), 'replicate')
        return self.forward(c)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                # print(f"Weight norm is removed from {m}.")
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                # print(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    @staticmethod
    def _get_receptive_field_size(layers,
                                  stacks,
                                  kernel_size,
                                  dilation=lambda x: 2 ** x):
        assert layers % stacks == 0
        layers_per_cycle = layers // stacks
        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
        return (kernel_size - 1) * sum(dilations) + 1

    @property
    def receptive_field_size(self):
        # note: the original referenced a non-existent `self.layers`;
        # the attribute defined in __init__ is `self.num_res_blocks`
        return self._get_receptive_field_size(self.num_res_blocks, self.stacks,
                                              self.kernel_size)
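As a worked check of the defaults (illustrative, not part of the commit): 30 residual blocks in 3 stacks give dilations 1, 2, ..., 512 repeated three times, so the receptive field is (3 - 1) * 3 * 1023 + 1 = 6139 samples, and each conditioning frame expands into prod([4, 4, 4, 4]) = 256 samples of noise input.

import torch
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator

model = ParallelWaveganGenerator()  # defaults match the config above
print(model.receptive_field_size)   # 6139 samples (~0.28 s at 22050 Hz)
mel = torch.rand(1, 80, 64)         # (B, num_mels, T_frames)
wav = model(mel)
print(wav.shape)                    # torch.Size([1, 1, 16384]) = 64 * 256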
@@ -0,0 +1,41 @@
import numpy as np
import torch

from TTS.vocoder.models.parallel_wavegan_discriminator import ParallelWaveganDiscriminator, ResidualParallelWaveganDiscriminator


def test_pwgan_discriminator():
    model = ParallelWaveganDiscriminator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=10,
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True)
    dummy_x = torch.rand((4, 1, 64 * 256))
    output = model(dummy_x)
    assert np.all(output.shape == (4, 1, 64 * 256))
    model.remove_weight_norm()


def test_residual_pwgan_discriminator():
    model = ResidualParallelWaveganDiscriminator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        dropout=0.0,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2})
    dummy_x = torch.rand((4, 1, 64 * 256))
    output = model(dummy_x)
    assert np.all(output.shape == (4, 1, 64 * 256))
    model.remove_weight_norm()
@@ -0,0 +1,30 @@
import numpy as np
import torch

from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator


def test_pwgan_generator():
    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=80,
        aux_context_window=2,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        use_causal_conv=False,
        upsample_conditional_features=True,
        upsample_factors=[4, 4, 4, 4])
    dummy_c = torch.rand((4, 80, 64))
    output = model(dummy_c)
    assert np.all(output.shape == (4, 1, 64 * 256))
    model.remove_weight_norm()
    output = model.inference(dummy_c)
    # inference() pads the input by `inference_padding` (2) frames on each
    # side, hence the 4 extra frames in the output length
    assert np.all(output.shape == (4, 1, (64 + 4) * 256))
@@ -0,0 +1,12 @@
import numpy as np
import tensorflow as tf

from TTS.vocoder.tf.models.melgan_generator import MelganGenerator


def test_melgan_generator():
    hop_length = 256
    model = MelganGenerator()
    dummy_input = tf.random.uniform((4, 80, 64))
    output = model(dummy_input, training=False)
    assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape
@@ -0,0 +1,29 @@
import os
import tensorflow as tf

import soundfile as sf
from librosa.core import load

from TTS.tests import get_tests_path, get_tests_input_path
from TTS.vocoder.tf.layers.pqmf import PQMF


TESTS_PATH = get_tests_path()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")


def test_pqmf():
    w, sr = load(WAV_FILE)

    layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
    w2 = tf.convert_to_tensor(w[None, None, :])
    b2 = layer.analysis(w2)    # split into N sub-band signals
    w2_ = layer.synthesis(b2)  # recombine into a full-band signal
    w2_ = w2_.numpy()

    print(w2_.max())
    print(w2_.min())
    print(w2_.mean())
    sf.write('tf_pqmf_output.wav', w2_.flatten(), sr)
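The test only writes the round-trip to disk; a quick numeric check (an illustrative addition, not part of the commit) could compare input and output, keeping in mind that the analysis/synthesis filter bank introduces a small delay:

import numpy as np

n = min(w.size, w2_.size)
rmse = np.sqrt(np.mean((w[:n] - w2_.flatten()[:n]) ** 2))
print(f"PQMF round-trip RMSE (no delay compensation): {rmse:.4f}")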
@@ -124,6 +124,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             y_hat_vis = y_hat
             y_G_sub = model_G.pqmf_analysis(y_G)
 
+        scores_fake, feats_fake, feats_real = None, None, None
         if global_step > c.steps_to_start_discriminator:
 
             # run D with or without cond. features
@@ -146,8 +147,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                     _, feats_real = D_out_real
             else:
                 scores_fake = D_out_fake
-        else:
-            scores_fake, feats_fake, feats_real = None, None, None
 
         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
@@ -328,6 +327,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
             y_G_sub = model_G.pqmf_analysis(y_G)
 
 
+        scores_fake, feats_fake, feats_real = None, None, None
         if global_step > c.steps_to_start_discriminator:
 
             if len(signature(model_D.forward).parameters) == 2:
@@ -349,8 +349,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
                     _, feats_real = D_out_real
             else:
                 scores_fake = D_out_fake
-        else:
-            scores_fake, feats_fake, feats_real = None, None, None
+                feats_fake, feats_real = None, None
 
         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
@@ -67,14 +67,34 @@ def setup_generator(c):
             upsample_factors=c.generator_model_params['upsample_factors'],
             res_kernel=3,
             num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model in 'parallel_wavegan_generator':
+        model = MyModel(
+            in_channels=1,
+            out_channels=1,
+            kernel_size=3,
+            num_res_blocks=c.generator_model_params['num_res_blocks'],
+            stacks=c.generator_model_params['stacks'],
+            res_channels=64,
+            gate_channels=128,
+            skip_channels=64,
+            aux_channels=c.audio['num_mels'],
+            aux_context_window=c['conv_pad'],
+            dropout=0.0,
+            bias=True,
+            use_weight_norm=True,
+            upsample_conditional_features=True,
+            upsample_factors=c.generator_model_params['upsample_factors'])
     return model
 
 
 def setup_discriminator(c):
     print(" > Discriminator Model: {}".format(c.discriminator_model))
-    MyModel = importlib.import_module('TTS.vocoder.models.' +
-                                      c.discriminator_model.lower())
-    MyModel = getattr(MyModel, to_camel(c.discriminator_model))
+    if 'parallel_wavegan' in c.discriminator_model:
+        MyModel = importlib.import_module('TTS.vocoder.models.parallel_wavegan_discriminator')
+    else:
+        MyModel = importlib.import_module('TTS.vocoder.models.' +
+                                          c.discriminator_model.lower())
+    MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
     if c.discriminator_model in 'random_window_discriminator':
         model = MyModel(
             cond_channels=c.audio['num_mels'],
@@ -95,6 +115,33 @@ def setup_discriminator(c):
             max_channels=c.discriminator_model_params['max_channels'],
             downsample_factors=c.
             discriminator_model_params['downsample_factors'])
+    if c.discriminator_model == 'residual_parallel_wavegan_discriminator':
+        model = MyModel(
+            in_channels=1,
+            out_channels=1,
+            kernel_size=3,
+            num_layers=c.discriminator_model_params['num_layers'],
+            stacks=c.discriminator_model_params['stacks'],
+            res_channels=64,
+            gate_channels=128,
+            skip_channels=64,
+            dropout=0.0,
+            bias=True,
+            nonlinear_activation="LeakyReLU",
+            nonlinear_activation_params={"negative_slope": 0.2},
+        )
+    if c.discriminator_model == 'parallel_wavegan_discriminator':
+        model = MyModel(
+            in_channels=1,
+            out_channels=1,
+            kernel_size=3,
+            num_layers=c.discriminator_model_params['num_layers'],
+            conv_channels=64,
+            dilation_factor=1,
+            nonlinear_activation="LeakyReLU",
+            nonlinear_activation_params={"negative_slope": 0.2},
+            bias=True
+        )
     return model
 
 
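For reference, the dynamic dispatch above resolves a config name to a class via to_camel; assuming the usual snake_case-to-CamelCase conversion used elsewhere in the repo, it behaves like this sketch:

import importlib
import re

def to_camel(text):
    # 'parallel_wavegan_generator' -> 'ParallelWaveganGenerator'
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text.capitalize())

module = importlib.import_module('TTS.vocoder.models.parallel_wavegan_generator')
ParallelWaveganGenerator = getattr(module, to_camel('parallel_wavegan_generator'))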