From 35ad3090a1c3682915ec6de5ba8d5a07cae954be Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 14 Jul 2020 10:38:10 +0200 Subject: [PATCH 1/2] parallel wavegan initial implementation --- vocoder/configs/parallel_wavegan_config.json | 143 +++++++++++++ vocoder/layers/parallel_wavegan.py | 75 +++++++ vocoder/layers/upsample.py | 100 +++++++++ .../models/parallel_wavegan_discriminator.py | 192 ++++++++++++++++++ vocoder/models/parallel_wavegan_generator.py | 162 +++++++++++++++ .../test_parallel_wavegan_discriminator.py | 41 ++++ .../tests/test_parallel_wavegan_generator.py | 30 +++ vocoder/tests/test_tf_melgan_generator.py | 12 ++ vocoder/tests/test_tf_pqmf.py | 29 +++ vocoder/utils/generic_utils.py | 53 ++++- 10 files changed, 834 insertions(+), 3 deletions(-) create mode 100644 vocoder/configs/parallel_wavegan_config.json create mode 100644 vocoder/layers/parallel_wavegan.py create mode 100644 vocoder/layers/upsample.py create mode 100644 vocoder/models/parallel_wavegan_discriminator.py create mode 100644 vocoder/models/parallel_wavegan_generator.py create mode 100644 vocoder/tests/test_parallel_wavegan_discriminator.py create mode 100644 vocoder/tests/test_parallel_wavegan_generator.py create mode 100644 vocoder/tests/test_tf_melgan_generator.py create mode 100644 vocoder/tests/test_tf_pqmf.py diff --git a/vocoder/configs/parallel_wavegan_config.json b/vocoder/configs/parallel_wavegan_config.json new file mode 100644 index 00000000..fcd765bd --- /dev/null +++ b/vocoder/configs/parallel_wavegan_config.json @@ -0,0 +1,143 @@ +{ + "run_name": "pwgan", + "run_description": "parallel-wavegan training", + + // AUDIO PARAMETERS + "audio":{ + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. + "ref_level_db": 0, // reference level db, theoretically 20db is the sound of air. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1.0, // scaler value appplied after log transform of spectrogram. + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // DISTRIBUTED TRAINING + // "distributed":{ + // "backend": "nccl", + // "url": "tcp:\/\/localhost:54321" + // }, + + // MODEL PARAMETERS + "use_pqmf": true, + + // LOSS PARAMETERS + "use_stft_loss": true, + "use_subband_stft_loss": false, // USE ONLY WITH MULTIBAND MODELS + "use_mse_gan_loss": true, + "use_hinge_gan_loss": false, + "use_feat_match_loss": false, // use only with melgan discriminators + + // loss weights + "stft_loss_weight": 0.5, + "subband_stft_loss_weight": 0.5, + "mse_G_loss_weight": 2.5, + "hinge_G_loss_weight": 2.5, + "feat_match_loss_weight": 25, + + // multiscale stft loss parameters + "stft_loss_params": { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }, + + // subband multiscale stft loss parameters + "subband_stft_loss_params":{ + "n_ffts": [384, 683, 171], + "hop_lengths": [30, 60, 10], + "win_lengths": [150, 300, 60] + }, + + "target_loss": "avg_G_loss", // loss value to pick the best model to save after each epoch + + // DISCRIMINATOR + "discriminator_model": "parallel_wavegan_discriminator", + "discriminator_model_params":{ + "num_layers": 10 + }, + "steps_to_start_discriminator": 200000, // steps required to start GAN trainining.1 + + // GENERATOR + "generator_model": "parallel_wavegan_generator", + "generator_model_params": { + "upsample_factors":[4, 4, 4, 4], + "stacks": 3, + "num_res_blocks": 30 + }, + + // DATASET + "data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/", + "feature_path": null, + "seq_len": 25600, + "pad_short": 2000, + "conv_pad": 0, + "use_noise_augment": false, + "use_cache": true, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // TRAINING + "batch_size": 6, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + + // VALIDATION + "run_eval": true, + "test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time. + "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "epochs": 10000, // total number of epochs to train. + "wd": 0.0, // Weight decay weight. + "gen_clip_grad": -1, // Generator gradient clipping threshold. Apply gradient clipping if > 0 + "disc_clip_grad": -1, // Discriminator gradient clipping threshold. + "lr_scheduler_gen": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + "lr_scheduler_gen_params": { + "gamma": 0.5, + "milestones": [100000, 200000, 300000, 400000, 500000, 600000] + }, + "lr_scheduler_disc": "MultiStepLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate + "lr_scheduler_disc_params": { + "gamma": 0.5, + "milestones": [100000, 200000, 300000, 400000, 500000, 600000] + }, + "lr_gen": 1e-4, // Initial learning rate. If Noam decay is active, maximum learning rate. + "lr_disc": 1e-4, + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log traning on console. + "print_eval": false, // If True, it prints loss values for each step in eval run. + "save_step": 25000, // Number of training steps expected to plot training stats on TB and save model checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + + // DATA LOADING + "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 4, // number of evaluation data loader processes. + "eval_split_size": 10, + + // PATHS + "output_path": "/home/erogol/Models/LJSpeech/" +} + diff --git a/vocoder/layers/parallel_wavegan.py b/vocoder/layers/parallel_wavegan.py new file mode 100644 index 00000000..35a56e8d --- /dev/null +++ b/vocoder/layers/parallel_wavegan.py @@ -0,0 +1,75 @@ +import torch +from torch.nn import functional as F + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__(self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False + ): + super(ResidualBlock, self).__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = torch.nn.Conv1d(res_channels, gate_channels, kernel_size, + padding=padding, dilation=dilation, bias=bias) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) + + def forward(self, x, c): + """ + x: B x D_res x T + c: B x D_aux x T + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * (0.5 ** 2) + + return x, s diff --git a/vocoder/layers/upsample.py b/vocoder/layers/upsample.py new file mode 100644 index 00000000..1f70c9f6 --- /dev/null +++ b/vocoder/layers/upsample.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +from torch.nn import functional as F + + +class Stretch2d(torch.nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """ + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + """ + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class UpsampleNetwork(torch.nn.Module): + def __init__(self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + super(UpsampleNetwork, self).__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_factors: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsample) + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvUpsample(torch.nn.Module): + def __init__(self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False + ): + super(ConvUpsample, self).__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_factors=upsample_factors, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsampled), + """ + c_ = self.conv_in(c) + c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/vocoder/models/parallel_wavegan_discriminator.py b/vocoder/models/parallel_wavegan_discriminator.py new file mode 100644 index 00000000..de03ccdb --- /dev/null +++ b/vocoder/models/parallel_wavegan_discriminator.py @@ -0,0 +1,192 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock + + +class ParallelWaveganDiscriminator(nn.Module): + """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. + It classifies each audio window real/fake and returns a sequence + of predictions. + It is a stack of convolutional blocks with dilation. + """ + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): + super(ParallelWaveganDiscriminator, self).__init__() + assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." + assert dilation_factor > 0, " [!] dilation factor must be > 0." + self.conv_layers = nn.ModuleList() + conv_in_channels = in_channels + for i in range(num_layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor ** i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + nn.Conv1d(conv_in_channels, conv_channels, + kernel_size=kernel_size, padding=padding, + dilation=dilation, bias=bias), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params) + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = nn.Conv1d( + conv_in_channels, out_channels, + kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + self.apply_weight_norm() + + def forward(self, x): + """ + x : (B, 1, T). + Returns: + Tensor: (B, 1, T) + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveganDiscriminator(nn.Module): + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + super(ResidualParallelWaveganDiscriminator, self).__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.stacks = stacks + self.kernel_size = kernel_size + self.res_factor = math.sqrt(1.0 / num_layers) + + # check the number of num_layers and stacks + assert num_layers % stacks == 0 + layers_per_stack = num_layers // stacks + + # define first convolution + self.first_conv = nn.Sequential( + nn.Conv1d(in_channels, + res_channels, + kernel_size=1, + padding=0, + dilation=1, + bias=True), + getattr(nn, nonlinear_activation)(inplace=True, + **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = nn.ModuleList() + for layer in range(num_layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=False, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = nn.ModuleList([ + getattr(nn, nonlinear_activation)(inplace=True, + **nonlinear_activation_params), + nn.Conv1d(skip_channels, + skip_channels, + kernel_size=1, + padding=0, + dilation=1, + bias=True), + getattr(nn, nonlinear_activation)(inplace=True, + **nonlinear_activation_params), + nn.Conv1d(skip_channels, + out_channels, + kernel_size=1, + padding=0, + dilation=1, + bias=True), + ]) + + # apply weight norm + self.apply_weight_norm() + + def forward(self, x): + """ + x: (B, 1, T). + """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= self.res_factor + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + print(f"Weight norm is removed from {m}.") + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/vocoder/models/parallel_wavegan_generator.py b/vocoder/models/parallel_wavegan_generator.py new file mode 100644 index 00000000..56316a41 --- /dev/null +++ b/vocoder/models/parallel_wavegan_generator.py @@ -0,0 +1,162 @@ +import math +import numpy as np +import torch +from torch.nn.utils import weight_norm + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +from TTS.vocoder.layers.upsample import ConvUpsample + + +class ParallelWaveganGenerator(torch.nn.Module): + """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. + It is similar to WaveNet with no causal convolution. + It is conditioned on an aux feature (spectrogram) to generate + an output waveform from an input noise. + """ + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + aux_context_window=2, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + upsample_conditional_features=True, + upsample_net="ConvInUpsampleNetwork", + upsample_factors=[4, 4, 4, 4], + inference_padding=2): + + super(ParallelWaveganGenerator, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.num_res_blocks = num_res_blocks + self.stacks = stacks + self.kernel_size = kernel_size + self.upsample_factors = upsample_factors + self.upsample_scale = np.prod(upsample_factors) + self.inference_padding = inference_padding + + # check the number of layers and stacks + assert num_res_blocks % stacks == 0 + layers_per_stack = num_res_blocks // stacks + + # define first convolution + self.first_conv = torch.nn.Conv1d(in_channels, + res_channels, + kernel_size=1, + bias=True) + + # define conv + upsampling network + self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(num_res_blocks): + dilation = 2**(layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList([ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, + skip_channels, + kernel_size=1, + bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, + out_channels, + kernel_size=1, + bias=True), + ]) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """ + c: (B, C ,T'). + o: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) + x = x.to(self.first_conv.bias.device) + + # perform upsampling + if c is not None and self.upsample_net is not None: + c = self.upsample_net(c) + assert c.shape[-1] == x.shape[ + -1], f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" + + # encode to hidden representation + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, c) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + + return x + + def inference(self, c): + c = c.to(self.first_conv.weight.device) + c = torch.nn.functional.pad( + c, (self.inference_padding, self.inference_padding), 'replicate') + return self.forward(c) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + # print(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance( + m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + # print(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, + stacks, + kernel_size, + dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + return self._get_receptive_field_size(self.layers, self.stacks, + self.kernel_size) diff --git a/vocoder/tests/test_parallel_wavegan_discriminator.py b/vocoder/tests/test_parallel_wavegan_discriminator.py new file mode 100644 index 00000000..b496e216 --- /dev/null +++ b/vocoder/tests/test_parallel_wavegan_discriminator.py @@ -0,0 +1,41 @@ +import numpy as np +import torch + +from TTS.vocoder.models.parallel_wavegan_discriminator import ParallelWaveganDiscriminator, ResidualParallelWaveganDiscriminator + + +def test_pwgan_disciminator(): + model = ParallelWaveganDiscriminator( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True) + dummy_x = torch.rand((4, 1, 64 * 256)) + output = model(dummy_x) + assert np.all(output.shape == (4, 1, 64 * 256)) + model.remove_weight_norm() + + +def test_redisual_pwgan_disciminator(): + model = ResidualParallelWaveganDiscriminator( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}) + dummy_x = torch.rand((4, 1, 64 * 256)) + output = model(dummy_x) + assert np.all(output.shape == (4, 1, 64 * 256)) + model.remove_weight_norm() diff --git a/vocoder/tests/test_parallel_wavegan_generator.py b/vocoder/tests/test_parallel_wavegan_generator.py new file mode 100644 index 00000000..f904ed24 --- /dev/null +++ b/vocoder/tests/test_parallel_wavegan_generator.py @@ -0,0 +1,30 @@ +import numpy as np +import torch + +from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator + + +def test_pwgan_generator(): + model = ParallelWaveganGenerator( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + aux_context_window=2, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + upsample_conditional_features=True, + upsample_factors=[4, 4, 4, 4]) + dummy_c = torch.rand((4, 80, 64)) + output = model(dummy_c) + assert np.all(output.shape == (4, 1, 64 * 256)) + model.remove_weight_norm() + output = model.inference(dummy_c) + assert np.all(output.shape == (4, 1, (64 + 4) * 256)) diff --git a/vocoder/tests/test_tf_melgan_generator.py b/vocoder/tests/test_tf_melgan_generator.py new file mode 100644 index 00000000..40a167a2 --- /dev/null +++ b/vocoder/tests/test_tf_melgan_generator.py @@ -0,0 +1,12 @@ +import numpy as np +import tensorflow as tf + +from TTS.vocoder.tf.models.melgan_generator import MelganGenerator + +def test_melgan_generator(): + hop_length = 256 + model = MelganGenerator() + dummy_input = tf.random.uniform((4, 80, 64)) + output = model(dummy_input, training=False) + assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape + diff --git a/vocoder/tests/test_tf_pqmf.py b/vocoder/tests/test_tf_pqmf.py new file mode 100644 index 00000000..75f00d5f --- /dev/null +++ b/vocoder/tests/test_tf_pqmf.py @@ -0,0 +1,29 @@ +import os +import tensorflow as tf + +import soundfile as sf +from librosa.core import load + +from TTS.tests import get_tests_path, get_tests_input_path +from TTS.vocoder.tf.layers.pqmf import PQMF + + +TESTS_PATH = get_tests_path() +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + + +def test_pqmf(): + w, sr = load(WAV_FILE) + + layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) + w, sr = load(WAV_FILE) + w2 = tf.convert_to_tensor(w[None, None, :]) + b2 = layer.analysis(w2) + w2_ = layer.synthesis(b2) + w2_ = w2.numpy() + + print(w2_.max()) + print(w2_.min()) + print(w2_.mean()) + sf.write('tf_pqmf_output.wav', w2_.flatten(), sr) + diff --git a/vocoder/utils/generic_utils.py b/vocoder/utils/generic_utils.py index 80c97f1a..031d299d 100644 --- a/vocoder/utils/generic_utils.py +++ b/vocoder/utils/generic_utils.py @@ -67,14 +67,34 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) + if c.generator_model in 'parallel_wavegan_generator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=c.generator_model_params['num_res_blocks'], + stacks=c.generator_model_params['stacks'], + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=c.audio['num_mels'], + aux_context_window=c['conv_pad'], + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_conditional_features=True, + upsample_factors=c.generator_model_params['upsample_factors']) return model def setup_discriminator(c): print(" > Discriminator Model: {}".format(c.discriminator_model)) - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.discriminator_model.lower()) - MyModel = getattr(MyModel, to_camel(c.discriminator_model)) + if 'parallel_wavegan' in c.discriminator_model: + MyModel = importlib.import_module('TTS.vocoder.models.parallel_wavegan_discriminator') + else: + MyModel = importlib.import_module('TTS.vocoder.models.' + + c.discriminator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) if c.discriminator_model in 'random_window_discriminator': model = MyModel( cond_channels=c.audio['num_mels'], @@ -95,6 +115,33 @@ def setup_discriminator(c): max_channels=c.discriminator_model_params['max_channels'], downsample_factors=c. discriminator_model_params['downsample_factors']) + if c.discriminator_model == 'residual_parallel_wavegan_discriminator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params['num_layers'], + stacks=c.discriminator_model_params['stacks'], + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ) + if c.discriminator_model == 'parallel_wavegan_discriminator': + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params['num_layers'], + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True + ) return model From 2c51ffd6ddb029003c2a30b494dae8d191a7aaa9 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 15 Jul 2020 10:55:45 +0200 Subject: [PATCH 2/2] vocoder train.py update --- vocoder/train.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vocoder/train.py b/vocoder/train.py index fd44c470..dc081a5e 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -124,6 +124,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, y_hat_vis = y_hat y_G_sub = model_G.pqmf_analysis(y_G) + scores_fake, feats_fake, feats_real = None, None, None if global_step > c.steps_to_start_discriminator: # run D with or without cond. features @@ -146,8 +147,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, _, feats_real = D_out_real else: scores_fake = D_out_fake - else: - scores_fake, feats_fake, feats_real = None, None, None # compute losses loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, @@ -328,6 +327,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) y_G_sub = model_G.pqmf_analysis(y_G) + scores_fake, feats_fake, feats_real = None, None, None if global_step > c.steps_to_start_discriminator: if len(signature(model_D.forward).parameters) == 2: @@ -349,8 +349,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) _, feats_real = D_out_real else: scores_fake = D_out_fake - else: - scores_fake, feats_fake, feats_real = None, None, None + feats_fake, feats_real = None, None # compute losses loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,