linter fixes

erogol 2020-07-28 14:53:56 +02:00
parent ade5fc2675
commit c5905cfa50
18 changed files with 66 additions and 485 deletions

View File

@@ -11,15 +11,13 @@ import numpy as np
import tensorflow as tf
import torch
from fuzzywuzzy import fuzz
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, tf_create_dummy_inputs,
transfer_weights_torch_to_tf)
from TTS.tts.tf.utils.generic_utils import save_checkpoint
sys.path.append('/home/erogol/Projects')
os.environ['CUDA_VISIBLE_DEVICES'] = ''
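The fuzzywuzzy import above is presumably what this converter script leans on to pair PyTorch parameter names with their TensorFlow counterparts. A minimal sketch of that matching idea, assuming hypothetical name lists (match_names and the threshold are illustrative, not part of the script):

from fuzzywuzzy import fuzz

def match_names(torch_names, tf_names, threshold=90):
    # For each torch parameter, pick the TF variable whose name scores
    # highest under fuzzywuzzy's similarity ratio.
    matches = {}
    for t_name in torch_names:
        best = max(tf_names, key=lambda tf_name: fuzz.ratio(t_name, tf_name))
        if fuzz.ratio(t_name, best) >= threshold:
            matches[t_name] = best
    return matches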

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys
import os
import sys
import pathlib
import time
import subprocess
@@ -37,8 +38,8 @@ def main():
group_id = time.strftime("%Y_%m_%d-%H%M%S")
# set arguments for train.py
folder_path =pathlib.Path(__file__).parent.absolute()
command = [os.path.join(folder_path,'train_tts.py')]
folder_path = pathlib.Path(__file__).parent.absolute()
command = [os.path.join(folder_path, 'train_tts.py')]
command.append('--continue_path={}'.format(args.continue_path))
command.append('--restore_path={}'.format(args.restore_path))
command.append('--config_path={}'.format(args.config_path))
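This hunk only assembles the train_tts.py argument list; a hedged sketch of how such a list can then be launched once per GPU (num_gpus is an assumed argparse value, not part of this diff):

import os
import subprocess

processes = []
for gpu_id in range(num_gpus):  # num_gpus assumed to come from argparse
    env = os.environ.copy()
    env['CUDA_VISIBLE_DEVICES'] = str(gpu_id)  # pin each worker to one GPU
    processes.append(subprocess.Popen(['python'] + command, env=env))
for p in processes:
    p.wait()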

View File

@@ -1,109 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import librosa
import yaml
import shutil
import argparse
import matplotlib.pyplot as plt
import math, pickle, os, glob
import numpy as np
from tqdm import tqdm
from TTS.tts.utils.audio import AudioProcessor
from TTS.tts.utils.generic_utils import load_config
from multiprocessing import Pool
os.environ["OMP_NUM_THREADS"] = "1"
def get_files(path, extension=".wav"):
filenames = []
for filename in glob.iglob(f"{path}/**/*{extension}", recursive=True):
filenames += [filename]
return filenames
def _process_file(path):
wav = ap.load_wav(path)
mel = ap.melspectrogram(wav)
wav = wav.astype(np.float32)
# check
assert len(wav.shape) == 1, \
f"{path} seems to be a multi-channel signal."
assert np.abs(wav).max() <= 1.0, \
f"{path} does not look like normalized 16-bit PCM (values exceed 1.0)."
# gap when wav is not multiple of hop_length
gap = wav.shape[0] % ap.hop_length
assert mel.shape[1] * ap.hop_length == wav.shape[0] + ap.hop_length - gap, f'{mel.shape[1] * ap.hop_length} vs {wav.shape[0] + ap.hop_length - gap}'
return mel.astype(np.float32), wav
def extract_feats(wav_path):
idx = wav_path.split("/")[-1][:-4]
m, wav = _process_file(wav_path)
mel_path = f"{MEL_PATH}{idx}.npy"
np.save(mel_path, m.astype(np.float32), allow_pickle=False)
return wav_path, mel_path
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_path", type=str, help="path to config file for feature extraction."
)
parser.add_argument(
"--num_procs", type=int, default=4, help="number of parallel processes."
)
parser.add_argument(
"--data_path", type=str, default='', help="path to audio files."
)
parser.add_argument(
"--out_path", type=str, default='', help="destination to write files."
)
parser.add_argument(
"--ignore_errors", action="store_true", help="ignore bad files."
)
args = parser.parse_args()
# load config
config = load_config(args.config_path)
config.update(vars(args))
config.audio['do_trim_silence'] = False
# config['audio']['signal_norm'] = False # do not apply earlier normalization
ap = AudioProcessor(**config['audio'])
SEG_PATH = config['data_path']
OUT_PATH = args.out_path
MEL_PATH = os.path.join(OUT_PATH, "mel/")
os.makedirs(OUT_PATH, exist_ok=True)
os.makedirs(MEL_PATH, exist_ok=True)
# TODO: use TTS data processors
wav_files = get_files(SEG_PATH)
print(" > Number of audio files : {}".format(len(wav_files)))
wav_file = wav_files[0]
m, wav = _process_file(wav_file)
# sanity check
print(' > Sample Spec Stats...')
print(' | > spectrogram max:', m.max())
print(' | > spectrogram min: ', m.min())
print(' | > spectrogram shape:', m.shape)
print(' | > wav shape:', wav.shape)
print(' | > wav max - min:', wav.max(), ' - ', wav.min())
# This will take a while depending on size of dataset
#with Pool(args.num_procs) as p:
# dataset_ids = list(tqdm(p.imap(extract_feats, wav_files), total=len(wav_files)))
dataset_ids = []
for wav_file in tqdm(wav_files):
item_id = extract_feats(wav_file)
dataset_ids.append(item_id)
# save metadata
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
for data in dataset_ids:
f.write(f"{data[0]}|{data[1]}\n")
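The serial loop above can be swapped for the commented-out Pool variant essentially verbatim; on Linux the fork start method lets workers inherit the module-level ap and MEL_PATH globals that extract_feats relies on:

from multiprocessing import Pool

# parallel feature extraction, one worker per process
with Pool(args.num_procs) as p:
    dataset_ids = list(tqdm(p.imap(extract_feats, wav_files),
                            total=len(wav_files)))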

View File

@@ -20,7 +20,7 @@ from TTS.vocoder.utils.generic_utils import setup_generator
def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_id):
t_1 = time.time()
waveform, _, _, mel_postnet_spec, stop_tokens, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, use_gl)
waveform, _, _, mel_postnet_spec, _, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, use_gl)
if CONFIG.model == "Tacotron" and not use_gl:
mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T
if not use_gl:

View File

@@ -9,19 +9,21 @@ import traceback
import torch
from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.generic_utils import save_best_model
from TTS.speaker_encoder.loss import GE2ELoss
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.visual import plot_embeddings
from TTS.speaker_encoder.generic_utils import save_best_model
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.utils.audio import AudioProcessor
from TTS.tts.utils.generic_utils import (create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.tts.utils.io import load_config, copy_config_file
from TTS.tts.utils.training import check_update, NoamLR
from TTS.tts.utils.tensorboard_logger import TensorboardLogger
from TTS.tts.utils.generic_utils import (create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict)
from TTS.tts.utils.io import copy_config_file, load_config
from TTS.tts.utils.radam import RAdam
from TTS.tts.utils.tensorboard_logger import TensorboardLogger
from TTS.tts.utils.training import NoamLR, check_update
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

View File

@@ -1,34 +0,0 @@
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
class DurationCalculator():
def calculate_durations(self, att_ws, ilens, olens):
"""calculate duration from given alignment matrices"""
durations = [self._calculate_duration(att_w, ilen, olen) for att_w, ilen, olen in zip(att_ws, ilens, olens)]
return pad_sequence(durations, batch_first=True)
@staticmethod
def _calculate_duration(att_w, ilen, olen):
'''
att_w: outs x ins alignment weights for a single item
'''
durations = torch.stack([att_w[:olen, :ilen].argmax(-1).eq(i).sum() for i in range(ilen)])
return durations
def calculate_scores(self, att_ws, ilens, olens, k):
"""calculate scores per duration step"""
scores = [self._calculate_scores(att_w, ilen, olen, k) for att_w, ilen, olen in zip(att_ws, ilens, olens)]
return pad_sequence(scores, batch_first=True)
@staticmethod
def _calculate_scores(att_w, ilen, olen, k):
# which input is attended for each output
scores = [None] * ilen
values, idxs = att_w[:olen, :ilen].max(-1)
for i in range(ilen):
vals = values[torch.where(idxs == i)]
scores[i] = vals
scores = [torch.nn.functional.pad(score, (0, k - score.shape[0])) for score in scores]
return torch.stack(scores)
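A hypothetical smoke test for the class above, using a random alignment for a single item (6 decoder frames attending over 4 input symbols); per-symbol durations sum to the output length:

import torch

att_ws = torch.softmax(torch.rand(1, 6, 4), dim=-1)  # batch x outs x ins
ilens = torch.tensor([4])  # input (text) lengths
olens = torch.tensor([6])  # output (frame) lengths
durations = DurationCalculator().calculate_durations(att_ws, ilens, olens)
print(durations)  # shape (1, 4); each row sums to the matching olen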

View File

@@ -1,9 +1,4 @@
import os
import glob
import torch
import shutil
import datetime
import subprocess
import importlib
import numpy as np
from collections import Counter

View File

@@ -1,247 +0,0 @@
"""
BSD 3-Clause License
Copyright (c) 2017, Prem Seetharaman
All rights reserved.
* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny, normalize
from librosa.filters import mel as librosa_mel_fn
def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
n_fft=800, dtype=np.float32, norm=None):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft
n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)
# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = normalize(win_sq, norm=norm)**2
win_sq = pad_center(win_sq, n_fft)
# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x
def amp_to_db(x):
o = 20 * torch.log10(torch.clamp(x, min=1e-5))
return o
def db_to_amp(x):
o = torch.pow(10.0, x * 0.05)  # inverse of amp_to_db: 10 ** (x / 20)
return o
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
def __init__(self, filter_length=800, hop_length=200, win_length=800,
window='hann', padding_mode='reflect', use_cuda=False):
super(STFT, self).__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.padding_mode = padding_mode
self.use_cuda = use_cuda
self.forward_transform = None
scale = self.filter_length / self.hop_length
fourier_basis = np.fft.fft(np.eye(self.filter_length))
cutoff = int((self.filter_length / 2 + 1))
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
np.imag(fourier_basis[:cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
inverse_basis = torch.FloatTensor(
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
if window is not None:
assert(filter_length >= win_length)
# get window and zero center pad it to filter_length
fft_window = get_window(window, win_length, fftbins=True)
fft_window = pad_center(fft_window, filter_length)
fft_window = torch.from_numpy(fft_window).float()
# window the bases
forward_basis *= fft_window
inverse_basis *= fft_window
self.register_buffer('forward_basis', forward_basis.float())
self.register_buffer('inverse_basis', inverse_basis.float())
def transform(self, input_data):
num_batches = input_data.size(0)
num_samples = input_data.size(1)
self.num_samples = num_samples
# similar to librosa, reflect-pad the input
input_data = input_data.view(num_batches, 1, num_samples)
input_data = F.pad(
input_data.unsqueeze(1),
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
mode=self.padding_mode)
input_data = input_data.squeeze(1)
# https://github.com/NVIDIA/tacotron2/issues/125
if self.use_cuda:
forward_transform = F.conv1d(
input_data.cuda(),
Variable(self.forward_basis, requires_grad=False).cuda(),
stride=self.hop_length,
padding=0).cpu()
else:
forward_transform = F.conv1d(
input_data,
Variable(self.forward_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
cutoff = int((self.filter_length / 2) + 1)
real_part = forward_transform[:, :cutoff, :]
imag_part = forward_transform[:, cutoff:, :]
magnitude = torch.sqrt(real_part**2 + imag_part**2)
phase = torch.autograd.Variable(
torch.atan2(imag_part.data, real_part.data))
return magnitude, phase
def inverse(self, magnitude, phase):
recombine_magnitude_phase = torch.cat(
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
inverse_transform = F.conv_transpose1d(
recombine_magnitude_phase,
Variable(self.inverse_basis, requires_grad=False),
stride=self.hop_length,
padding=0)
if self.window is not None:
window_sum = window_sumsquare(
self.window, magnitude.size(-1), hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.filter_length,
dtype=np.float32)
# remove modulation effects
approx_nonzero_indices = torch.from_numpy(
np.where(window_sum > tiny(window_sum))[0])
window_sum = torch.autograd.Variable(
torch.from_numpy(window_sum), requires_grad=False)
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
# scale by hop ratio
inverse_transform *= float(self.filter_length) / self.hop_length
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
return inverse_transform
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=None, padding_mode='constant'):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length, padding_mode=padding_mode)
mel_basis = librosa_mel_fn(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer('mel_basis', mel_basis)
def spectral_normalize(self, magnitudes):
output = amp_to_db(magnitudes)
return output
def spectral_de_normalize(self, magnitudes):
output = db_to_amp(magnitudes)
return output
def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert(torch.min(y.data) >= -1)
assert(torch.max(y.data) <= 1)
magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
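A hedged usage sketch for the module removed here, following the mel_spectrogram docstring (a batch of waveforms in [-1, 1] in, a (B, n_mel_channels, T) mel out):

import torch

stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
                    n_mel_channels=80, sampling_rate=22050)
wav = torch.rand(2, 22050) * 2 - 1  # one second of noise per item, in [-1, 1)
mel = stft.mel_spectrogram(wav)
print(mel.shape)  # torch.Size([2, 80, n_frames])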

View File

@@ -1,32 +0,0 @@
import os
import subprocess
import tempfile
import nbformat
def _notebook_run(path):
"""Execute a notebook via nbconvert and collect output.
:returns (parsed nb object, execution errors)
"""
dirname, filename = os.path.split(path)
os.chdir(dirname)
with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
args = ["jupyter", "nbconvert", "--to", "notebook", "--execute",
"--ExecutePreprocessor.timeout=60",
"--output", fout.name, filename]
subprocess.check_call(args)
fout.seek(0)
nb = nbformat.read(fout, nbformat.current_nbformat)
errors = [output for cell in nb.cells if "outputs" in cell
for output in cell["outputs"]
if output.output_type == "error"]
return nb, errors
def test_ipynb(path):
nb, errors = _notebook_run(path)
assert errors == []
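An illustrative call site for this helper; the notebook name is hypothetical:

def test_demo_notebook():
    # executes the notebook end-to-end and fails on any error output
    test_ipynb(os.path.join(os.path.dirname(__file__), "demo.ipynb"))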

View File

@@ -31,14 +31,13 @@ def _expand_dollars(m):
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
if dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
if cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
return 'zero dollars'
def _expand_ordinal(m):
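The refactor above drops elif/else branches that follow a return, flattening the control flow without changing behavior. A hedged sketch of how this normalizer is typically wired up; _dollars_re is an assumed pattern, and the dollars/cents parsing sits above this hunk:

import re

_dollars_re = re.compile(r'\$([0-9.]+)')  # assumption modeled on the module's own pattern

def demo():  # hypothetical helper
    # "$2.50" -> "2 dollars, 50 cents"; number-to-word expansion happens later
    return re.sub(_dollars_re, _expand_dollars, 'it costs $2.50')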

View File

@@ -1,12 +1,8 @@
import os
import glob
import torch
import shutil
import datetime
import subprocess
import importlib
import numpy as np
from collections import Counter
def get_git_branch():

View File

@@ -4,7 +4,6 @@ from torch.nn import functional as F
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(self,
kernel_size=3,
res_channels=64,
@@ -14,32 +13,45 @@ class ResidualBlock(torch.nn.Module):
dropout=0.0,
dilation=1,
bias=True,
use_causal_conv=False
):
use_causal_conv=False):
super(ResidualBlock, self).__init__()
self.dropout = dropout
# no future time stamps available
if use_causal_conv:
padding = (kernel_size - 1) * dilation
else:
assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
assert (kernel_size -
1) % 2 == 0, "Even kernel sizes are not supported."
padding = (kernel_size - 1) // 2 * dilation
self.use_causal_conv = use_causal_conv
# dilation conv
self.conv = torch.nn.Conv1d(res_channels, gate_channels, kernel_size,
padding=padding, dilation=dilation, bias=bias)
self.conv = torch.nn.Conv1d(res_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
self.conv1x1_aux = torch.nn.Conv1d(aux_channels,
gate_channels,
1,
bias=False)
else:
self.conv1x1_aux = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels,
res_channels,
1,
bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels,
skip_channels,
1,
bias=bias)
def forward(self, x, c):
"""
@@ -70,6 +82,6 @@ class ResidualBlock(torch.nn.Module):
s = self.conv1x1_skip(x)
# for residual connection
x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)
x = (self.conv1x1_out(x) + residual) * (0.5**2)
return x, s
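A hypothetical smoke test for the reformatted block, assuming the defaults hidden by the hunk (gate_channels=128, skip_channels=64, aux_channels=80):

import torch

block = ResidualBlock(kernel_size=3, res_channels=64, dilation=2)
x = torch.rand(1, 64, 100)  # (B, res_channels, T)
c = torch.rand(1, 80, 100)  # (B, aux_channels, T) local conditioning
x_out, skip = block(x, c)   # both keep T thanks to the symmetric padding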

View File

@@ -1,4 +1,3 @@
import numpy as np
import torch
from torch.nn import functional as F
@@ -20,6 +19,7 @@ class Stretch2d(torch.nn.Module):
class UpsampleNetwork(torch.nn.Module):
# pylint: disable=dangerous-default-value
def __init__(self,
upsample_factors,
nonlinear_activation=None,
@@ -64,6 +64,7 @@ class UpsampleNetwork(torch.nn.Module):
class ConvUpsample(torch.nn.Module):
# pylint: disable=dangerous-default-value
def __init__(self,
upsample_factors,
nonlinear_activation=None,

View File

@@ -1,7 +1,6 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
@@ -12,7 +11,7 @@ class ParallelWaveganDiscriminator(nn.Module):
of predictions.
It is a stack of convolutional blocks with dilation.
"""
# pylint: disable=dangerous-default-value
def __init__(self,
in_channels=1,
out_channels=1,
@@ -37,10 +36,15 @@ class ParallelWaveganDiscriminator(nn.Module):
conv_in_channels = conv_channels
padding = (kernel_size - 1) // 2 * dilation
conv_layer = [
nn.Conv1d(conv_in_channels, conv_channels,
kernel_size=kernel_size, padding=padding,
dilation=dilation, bias=bias),
getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
nn.Conv1d(conv_in_channels,
conv_channels,
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
bias=bias),
getattr(nn,
nonlinear_activation)(inplace=True,
**nonlinear_activation_params)
]
self.conv_layers += conv_layer
padding = (kernel_size - 1) // 2
@@ -62,7 +66,7 @@ class ParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
@@ -77,6 +81,7 @@ class ParallelWaveganDiscriminator(nn.Module):
class ResidualParallelWaveganDiscriminator(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(self,
in_channels=1,
out_channels=1,
@@ -177,7 +182,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
self.apply(_apply_weight_norm)
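Both weight-norm hunks in this file adopt isinstance's tuple-of-types form in place of chained or checks; behavior is identical. A minimal standalone illustration:

import torch

def _apply_weight_norm(m):
    # weight norm is applied only to 1D/2D convolutions
    if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
        torch.nn.utils.weight_norm(m)

net = torch.nn.Sequential(torch.nn.Conv1d(1, 4, 3), torch.nn.ReLU())
net.apply(_apply_weight_norm)  # visits every submodule recursively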

View File

@@ -1,7 +1,6 @@
import math
import numpy as np
import torch
from torch.nn.utils import weight_norm
from TTS.vocoder.layers.parallel_wavegan import ResidualBlock
from TTS.vocoder.layers.upsample import ConvUpsample
@@ -13,6 +12,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
It is conditioned on an aux feature (spectrogram) to generate
an output waveform from an input noise.
"""
# pylint: disable=dangerous-default-value
def __init__(self,
in_channels=1,
out_channels=1,
@@ -23,13 +23,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
gate_channels=128,
skip_channels=64,
aux_channels=80,
aux_context_window=2,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_net="ConvInUpsampleNetwork",
upsample_factors=[4, 4, 4, 4],
inference_padding=2):
@@ -140,8 +136,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
def apply_weight_norm(self):
def _apply_weight_norm(m):
if isinstance(m, torch.nn.Conv1d) or isinstance(
m, torch.nn.Conv2d):
if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
torch.nn.utils.weight_norm(m)
# print(f"Weight norm is applied to {m}.")

View File

@@ -69,7 +69,7 @@ def setup_generator(c):
num_res_blocks=c.generator_model_params['num_res_blocks'])
if c.generator_model in 'parallel_wavegan_generator':
model = MyModel(
in_channels=1,
in_channels=1,
out_channels=1,
kernel_size=3,
num_res_blocks=c.generator_model_params['num_res_blocks'],
@@ -90,10 +90,11 @@ def setup_generator(c):
def setup_discriminator(c):
print(" > Discriminator Model: {}".format(c.discriminator_model))
if 'parallel_wavegan' in c.discriminator_model:
MyModel = importlib.import_module('TTS.vocoder.models.parallel_wavegan_discriminator')
MyModel = importlib.import_module(
'TTS.vocoder.models.parallel_wavegan_discriminator')
else:
MyModel = importlib.import_module('TTS.vocoder.models.' +
c.discriminator_model.lower())
c.discriminator_model.lower())
MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
if c.discriminator_model in 'random_window_discriminator':
model = MyModel(
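For reference, the import pattern being re-wrapped here resolves a module by name and then pulls out its CamelCase class; a hedged distillation (load_model_class is a hypothetical helper, and to_camel is defined elsewhere in this module):

import importlib

def load_model_class(model_name):
    # e.g. 'melgan_generator' -> TTS.vocoder.models.melgan_generator.MelganGenerator
    module = importlib.import_module('TTS.vocoder.models.' + model_name.lower())
    return getattr(module, to_camel(model_name.lower()))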

View File

@@ -15,12 +15,9 @@ def test_pwgan_generator():
gate_channels=128,
skip_channels=64,
aux_channels=80,
aux_context_window=2,
dropout=0.0,
bias=True,
use_weight_norm=True,
use_causal_conv=False,
upsample_conditional_features=True,
upsample_factors=[4, 4, 4, 4])
dummy_c = torch.rand((2, 80, 5))
output = model(dummy_c)

View File

@@ -3,10 +3,11 @@ import tensorflow as tf
from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
def test_melgan_generator():
hop_length = 256
model = MelganGenerator()
# pylint: disable=no-value-for-parameter
dummy_input = tf.random.uniform((4, 80, 64))
output = model(dummy_input, training=False)
assert np.all(output.shape == (4, 1, 64 * hop_length)), output.shape