Ton of linter fixes

erogol 2020-06-02 18:58:10 +02:00
parent 58117c0c6e
commit d6b9a5738a
16 changed files with 60 additions and 231 deletions

View File

@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (count_parameters, create_experiment_folder,
from TTS.utils.io import (save_best_model, save_checkpoint,
load_config, copy_config_file)
from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
gradual_training_scheduler, set_weight_decay)
gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@@ -32,13 +33,8 @@ from TTS.datasets.preprocess import load_meta_data
from TTS.utils.radam import RAdam
from TTS.utils.measures import alignment_diagonal_score
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
use_cuda, num_gpus = setup_torch_training_env(True, False)
def setup_loader(ap, r, is_val=False, verbose=False):

View File

@@ -2,6 +2,17 @@ import torch
import numpy as np
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
torch.backends.cudnn.enabled = cudnn_enable
torch.backends.cudnn.benchmark = cudnn_benchmark
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
return use_cuda, num_gpus
def check_update(model, grad_clip, ignore_stopnet=False):
r'''Check model gradient against unexpected jumps and failures'''
skip_flag = False
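
The helper above consolidates the CUDA/cuDNN boilerplate that each training script previously duplicated. A minimal sketch of the intended call sites, with flag values taken from the two trainer diffs in this commit:

    # Tacotron trainer: cuDNN enabled, benchmark off (variable-length padded batches)
    use_cuda, num_gpus = setup_torch_training_env(True, False)

    # GAN vocoder trainer: benchmark on (fixed-size audio segments)
    use_cuda, num_gpus = setup_torch_training_env(True, True)

The benchmark flag is the one that matters: cudnn.benchmark autotunes kernels per input shape, which pays off only when shapes are constant, as with the vocoder's fixed segments but not with padded text/spectrogram batches.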

View File

@@ -3,29 +3,10 @@ import glob
import torch
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
from multiprocessing import Manager
def create_dataloader(hp, args, train):
dataset = MelFromDisk(hp, args, train)
if train:
return DataLoader(dataset=dataset,
batch_size=hp.train.batch_size,
shuffle=True,
num_workers=hp.train.num_workers,
pin_memory=True,
drop_last=True)
else:
return DataLoader(dataset=dataset,
batch_size=1,
shuffle=False,
num_workers=hp.train.num_workers,
pin_memory=True,
drop_last=False)
class GANDataset(Dataset):
"""
GAN Dataset searches for all the wav files under the root path
@@ -55,26 +36,26 @@ class GANDataset(Dataset):
self.return_segments = return_segments
self.use_cache = use_cache
self.use_noise_augment = use_noise_augment
self.verbose = verbose
assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)
# map G and D instances
self.G_to_D_mappings = [i for i in range(len(self.item_list))]
self.G_to_D_mappings = list(range(len(self.item_list)))
self.shuffle_mapping()
# cache acoustic features
if use_cache:
self.create_feature_cache()
def create_feature_cache(self):
self.manager = Manager()
self.cache = self.manager.list()
self.cache += [None for _ in range(len(self.item_list))]
def find_wav_files(self, path):
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
def __len__(self):
@@ -88,7 +69,6 @@ class GANDataset(Dataset):
item1 = self.load_item(idx)
item2 = self.load_item(idx2)
return item1, item2
else:
item1 = self.load_item(idx)
return item1
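
Three routine pylint fixes appear in this dataset hunk: `list(range(n))` replaces the redundant comprehension, `find_wav_files` becomes a @staticmethod because it never reads `self`, and the `else` after a `return` is dropped (no-else-return). A sketch of the resulting shape; the derivation of `idx2` is assumed from the surrounding mapping code, which this diff does not show:

    import glob
    import os
    from torch.utils.data import Dataset

    class GANDataset(Dataset):
        @staticmethod
        def find_wav_files(path):
            # pure function of its argument, so no instance state is needed
            return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)

        def __getitem__(self, idx):
            if self.return_segments:
                idx2 = self.G_to_D_mappings[idx]   # assumed pairing, not in this hunk
                return self.load_item(idx), self.load_item(idx2)
            return self.load_item(idx)             # early return makes the else redundant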

View File

@@ -51,13 +51,13 @@ class STFTLoss(nn.Module):
class MultiScaleSTFTLoss(torch.nn.Module):
def __init__(self,
n_ffts=[1024, 2048, 512],
hop_lengths=[120, 240, 50],
win_lengths=[600, 1200, 240]):
n_ffts=(1024, 2048, 512),
hop_lengths=(120, 240, 50),
win_lengths=(600, 1200, 240)):
super(MultiScaleSTFTLoss, self).__init__()
self.loss_funcs = torch.nn.ModuleList()
for idx in range(len(n_ffts)):
self.loss_funcs.append(STFTLoss(n_ffts[idx], hop_lengths[idx], win_lengths[idx]))
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
def forward(self, y_hat, y):
N = len(self.loss_funcs)
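
Two lint-driven idioms in this hunk: the list default arguments become tuples, and the index loop becomes a zip over the three parallel sequences. The STFT defaults are never mutated, so the tuple change is pylint's W0102 (dangerous-default-value) being conservative here, but the rule guards against a real pitfall, sketched below with a hypothetical function:

    def append_one(acc=[]):   # the default list is created once, at def time
        acc.append(1)
        return acc

    append_one()   # [1]
    append_one()   # [1, 1]  -- state leaked between calls via the shared default

Tuples are immutable, so they cannot accumulate state this way, and zip makes the pairing of n_fft/hop_length/win_length explicit without index bookkeeping.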
@@ -81,20 +81,14 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
def __init__(self,):
super(MSEGLoss, self).__init__()
def forward(self, score_fake, ):
def forward(self, score_fake):
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
return loss_fake
class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
def __init__(self,):
super(HingeGLoss, self).__init__()
def forward(self, score_fake, score_real):
def forward(self, score_fake):
loss_fake = torch.mean(F.relu(1. + score_fake))
return loss_fake
@@ -106,9 +100,6 @@ class HingeGLoss(nn.Module):
class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """
def __init__(self,):
super(MSEDLoss, self).__init__()
def forward(self, score_fake, score_real):
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
@@ -118,9 +109,6 @@ class MSEDLoss(nn.Module):
class HingeDLoss(nn.Module):
""" Hinge Discriminator Loss """
def __init__(self,):
super(HingeDLoss, self).__init__()
def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1. + score_fake))
@@ -129,13 +117,10 @@ class HingeDLoss(nn.Module):
class MelganFeatureLoss(nn.Module):
def __init__(self, ):
super(MelganFeatureLoss, self).__init__()
def forward(self, fake_feats, real_feats):
loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats):
loss_feats += hp.model.feat_match * torch.mean(torch.abs(fake_feat - real_feat))
loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
return loss_feats
@@ -147,7 +132,7 @@ class MelganFeatureLoss(nn.Module):
class GeneratorLoss(nn.Module):
def __init__(self, C):
super(GeneratorLoss, self).__init__()
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss
@@ -228,7 +213,7 @@ class GeneratorLoss(nn.Module):
class DiscriminatorLoss(nn.Module):
def __init__(self, C):
super(DiscriminatorLoss, self).__init__()
assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_mse_gan_loss = C.use_mse_gan_loss
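
The assert rewrite in both loss classes fixes a precedence bug, not just style: `==` binds tighter than `and`, so the old expression parsed as `C.use_mse_gan_loss and (C.use_hinge_gan_loss == False)`, and the assert fired whenever the MSE loss was simply disabled. The corrected `not (A and B)` rejects only the genuinely invalid combination:

    A = use_mse_gan_loss, B = use_hinge_gan_loss
    A      B      old: A and (B == False)     new: not (A and B)
    False  False  False -> assert fires       True  -> ok
    False  True   False -> assert fires       True  -> ok
    True   False  True  -> ok                 True  -> ok
    True   True   False -> assert fires       False -> assert fires (the only bad case)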

View File

@@ -1,7 +1,4 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm

View File

@@ -44,6 +44,9 @@ class PQMF(torch.nn.Module):
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def forward(self, x):
return self.analysis(x)
def analysis(self, x):
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

View File

@@ -1,127 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Pseudo QMF modules."""
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impulse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, "The number of taps must be an even number."
assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid='ignore'):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
/ (np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
"""Initilize PQMF module.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# define filter coefficient
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + (-1) ** k * np.pi / 4)
h_synthesis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - (-1) ** k * np.pi / 4)
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as buffer
self.register_buffer("analysis_filter", analysis_filter)
self.register_buffer("synthesis_filter", synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer("updown_filter", updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
# Not sure this is the correct way, it is better to check again.
# TODO(kan-bayashi): Understand the reconstruction procedure
x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
x = F.conv1d(self.pad_fn(x), self.synthesis_filter)
return x
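
The file deleted above is a second copy of the kan-bayashi PQMF implementation; the surviving TTS.vocoder.layers.pqmf module (patched earlier in this commit so that forward delegates to analysis) appears to implement the same filterbank with different internals (self.H, self.N). A minimal round-trip sketch; parameter names follow the deleted copy and may differ in the surviving module:

    import torch
    from TTS.vocoder.layers.pqmf import PQMF

    pqmf = PQMF(subbands=4)            # taps=62, cutoff_ratio=0.15, beta=9.0 defaults assumed
    wav = torch.randn(1, 1, 16000)     # (B, 1, T) dummy waveform
    bands = pqmf.analysis(wav)         # (B, 4, T // 4) subband signals
    recon = pqmf.synthesis(bands)      # (B, 1, ~T) near-perfect reconstruction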

View File

@@ -1,7 +1,5 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm
@@ -32,7 +30,7 @@ class MelganDiscriminator(nn.Module):
# downsampling layers
layer_in_channels = base_channels
for idx, downsample_factor in enumerate(downsample_factors):
for downsample_factor in downsample_factors:
layer_out_channels = min(layer_in_channels * downsample_factor,
max_channels)
layer_kernel_size = downsample_factor * 10 + 1

View File

@@ -1,7 +1,5 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm
from TTS.vocoder.layers.melgan import ResidualStack

View File

@@ -22,7 +22,7 @@ ok_ljspeech = os.path.exists(test_data_path)
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
''' run dataloader with given parameters and check conditions '''
ap = AudioProcessor(**C.audio)
eval_items, train_items = load_wav_data(test_data_path, 10)
_, train_items = load_wav_data(test_data_path, 10)
dataset = GANDataset(ap,
train_items,
seq_len=seq_len,
@@ -44,9 +44,9 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
# return random segments or return the whole audio
if return_segments:
for item1, item2 in loader:
for item1, _ in loader:
feat1, wav1 = item1
feat2, wav2 = item2
# feat2, wav2 = item2
expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)
# check shapes
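
Unpacking into `_` is the stock fix for pylint's unused-variable warning while keeping the tuple structure visible; the commented `# feat2, wav2 = item2` line is kept rather than deleted, presumably so the paired-sample path can be re-enabled when the shape checks cover it:

    _, train_items = load_wav_data(test_data_path, 10)   # eval split unused here
    for item1, _ in loader:                              # second item unused for now
        feat1, wav1 = item1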

View File

@@ -1,5 +1,4 @@
import os
import unittest
import torch
from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss
@@ -24,7 +23,7 @@ def test_torch_stft():
torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
# librosa stft
wav = ap.load_wav(WAV_FILE)
M_librosa = abs(ap._stft(wav))
M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access
# torch stft
wav = torch.from_numpy(wav[None, :]).float()
M_torch = torch_stft(wav)

View File

@@ -1,5 +1,4 @@
import numpy as np
import unittest
import torch
from TTS.vocoder.models.melgan_generator import MelganGenerator

View File

@@ -1,16 +1,11 @@
import os
import torch
import unittest
import numpy as np
import soundfile as sf
from librosa.core import load
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.tests import get_tests_path, get_tests_input_path
from TTS.vocoder.layers.pqmf import PQMF
from TTS.vocoder.layers.pqmf2 import PQMF as PQMF2
TESTS_PATH = get_tests_path()

View File

@@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import NoamLR
from TTS.utils.training import setup_torch_training_env, NoamLR
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data
# from distribute import (DistributedSampler, apply_gradient_allreduce,
@@ -29,13 +29,8 @@ from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
setup_discriminator,
setup_generator)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.manual_seed(54321)
use_cuda = torch.cuda.is_available()
num_gpus = torch.cuda.device_count()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
@@ -49,10 +44,11 @@ def setup_loader(ap, is_val=False, verbose=False):
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=False if is_val else True,
return_segments=not is_val,
use_noise_augment=c.use_noise_augment,
use_cache=c.use_cache,
verbose=verbose)
dataset.shuffle_mapping()
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size,
@@ -81,11 +77,11 @@ def format_data(data):
return c_G, x_G, c_D, x_D
# return a whole audio segment
c, x = data
co, x = data
if use_cuda:
c = c.cuda(non_blocking=True)
co = co.cuda(non_blocking=True)
x = x.cuda(non_blocking=True)
return c, x
return co, x, None, None
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
@@ -306,7 +302,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
start_time = time.time()
# format data
c_G, y_G = format_data(data)
c_G, y_G, _, _ = format_data(data)
loader_time = time.time() - end_time
global_step += 1
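
The format_data change makes the function's arity uniform: the segment path already returned (c_G, x_G, c_D, x_D), so the whole-audio path now returns (co, x, None, None) and every caller can unpack four values unconditionally, as evaluate does above with `c_G, y_G, _, _ = format_data(data)`. The rename of the local `c` to `co` is presumably to stop shadowing the module-level config object `c` these training scripts use.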
@@ -416,7 +412,7 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_gen.load_state_dict(checkpoint['optimizer'])
model_disc.load_state_dict(checkpoint['model_disc'])
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
except:
except KeyError:
print(" > Partial model initialization.")
model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
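
Narrowing the bare `except:` to `except KeyError` limits the fallback to the failure it is meant to handle: an older checkpoint dict missing one of the expected keys. A bare except would also have swallowed KeyboardInterrupt and genuine bugs. One subtlety worth noting: a tensor shape mismatch inside load_state_dict raises RuntimeError, not KeyError, so that case now propagates instead of being silently "partially initialized" — arguably the safer behavior, though the commit does not say whether it was intended.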

View File

@@ -91,7 +91,7 @@ class ConsoleLogger():
sign = ''
elif diff > 0:
color = tcolors.FAIL
sing = '+'
sign = '+'
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
self.old_eval_loss_dict = avg_loss_dict
print(log_text, flush=True)
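
This one-character fix repairs a silent-or-crashing bug: assignment creates names in Python, so the misspelled `sing` simply made a new variable and left `sign` either unset (an UnboundLocalError at the format call on the first worsened metric) or carrying a stale value from an earlier branch or loop iteration, printing worsened eval losses without their '+' prefix. It is exactly the class of defect a linter's unused-variable check surfaces.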

View File

@@ -2,7 +2,6 @@ import os
import torch
import datetime
def save_model(model, optimizer, model_disc, optimizer_disc, current_step,
epoch, output_path, **kwargs):
model_state = model.state_dict()