pylint fix

This commit is contained in:
erogol 2020-04-23 15:46:45 +02:00
parent 3673cc1e30
commit f63bce89f6
7 changed files with 27 additions and 30 deletions

View File

@ -224,7 +224,7 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets)
# compute linear spectrogram
# compute linear spectrogram
if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
linear = prepare_tensor(linear, self.outputs_per_step)

View File

@ -13,6 +13,7 @@ os.makedirs(OUT_PATH, exist_ok=True)
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
# pylint: disable=protected-access
class TestAudio(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(TestAudio, self).__init__(*args, **kwargs)
@ -165,7 +166,7 @@ class TestAudio(unittest.TestCase):
self.ap.signal_norm = False
self.ap.preemphasis = 0.0
# test scaler forward and backward transforms
wav = self.ap.load_wav(WAV_FILE)
mel_reference = self.ap.melspectrogram(wav)

View File

@ -7,7 +7,6 @@ import traceback
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset
@ -20,7 +19,7 @@ from TTS.utils.generic_utils import (
get_git_branch, load_config, remove_experiment_folder, save_best_model,
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
setup_model, gradual_training_scheduler, KeepAverage,
set_weight_decay, check_config, print_train_step)
set_weight_decay, check_config)
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@ -215,7 +214,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
update_train_values = {
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
'avg_stopnet_loss': loss_dict['stopnet_loss'].item()
'avg_stopnet_loss': loss_dict['stopnet_loss'].item() \
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
'avg_step_time': step_time,
'avg_loader_time': loader_time
@ -591,8 +590,8 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap,
global_step, epoch)
optimizer_st, scheduler, ap,
global_step, epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss']

View File

@ -72,7 +72,7 @@ class AudioProcessor(object):
# setup scaler
if stats_path:
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
self.setup_scaler(mel_mean, mel_std, linear_mean,linear_std)
self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
self.signal_norm = True
self.max_norm = None
self.clip_norm = None
@ -107,7 +107,7 @@ class AudioProcessor(object):
# mean-var scaling
if hasattr(self, 'mel_scaler'):
if S.shape[0] == self.num_mels:
return self.mel_scaler.transform(S.T).T
return self.mel_scaler.transform(S.T).T
elif S.shape[0] == self.n_fft / 2:
return self.linear_scaler.transform(S.T).T
else:
@ -136,7 +136,7 @@ class AudioProcessor(object):
# mean-var scaling
if hasattr(self, 'mel_scaler'):
if S_denorm.shape[0] == self.num_mels:
return self.mel_scaler.inverse_transform(S_denorm.T).T
return self.mel_scaler.inverse_transform(S_denorm.T).T
elif S_denorm.shape[0] == self.n_fft / 2:
return self.linear_scaler.inverse_transform(S_denorm.T).T
else:
@ -168,10 +168,10 @@ class AudioProcessor(object):
for key in stats_config.keys():
if key in skip_parameters:
continue
assert stats_config[key] == self.__dict__[
key], f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
assert stats_config[key] == self.__dict__[key],\
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init
def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
self.mel_scaler = StandardScaler()
@ -180,9 +180,11 @@ class AudioProcessor(object):
self.linear_scaler.set_stats(linear_mean, linear_std)
### DB and AMP conversion ###
# pylint: disable=no-self-use
def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x))
# pylint: disable=no-self-use
def _db_to_amp(self, x):
return np.power(10.0, x * 0.05)
@ -269,15 +271,14 @@ class AudioProcessor(object):
y = self._istft(S_complex * angles)
return y
def compute_stft_paddings(self,x, pad_sides=1):
def compute_stft_paddings(self, x, pad_sides=1):
'''compute right padding (final frame) or both sides padding (first and final frames)
'''
assert pad_sides in (1, 2)
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
if pad_sides == 1:
return 0, pad
else:
return pad // 2, pad // 2 + pad % 2
return pad // 2, pad // 2 + pad % 2
### Audio Processing ###
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):

View File

@ -22,6 +22,7 @@ class ConsoleLogger():
self.old_epoch_loss_dict = None
self.old_eval_loss_dict = None
# pylint: disable=no-self-use
def get_time(self):
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@ -47,10 +48,11 @@ class ConsoleLogger():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}\
step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
print(log_text, flush=True)
# pylint: disable=unused-argument
def print_train_epoch_end(self, global_step, epoch, epoch_time,
print_dict):
indent = " | > "

View File

@ -369,7 +369,7 @@ class KeepAverage():
return self.avg_values[key]
def items(self):
return self.avg_values.items()
return self.avg_values.items()
def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val
@ -412,7 +412,7 @@ def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restric
tcolors = AttrDict({
'OKBLUE': '\033[94m',
'HEADER': '\033[95m',
'HEADER': '\033[95m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
@ -428,17 +428,10 @@ def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_l
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
for key, value in print_dict.items():
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\
\n{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
print(log_text, flush=True)
def print_train_epoch(step, global_step, epoch, loss_dict):
    """No-op placeholder; per-epoch training output is produced elsewhere."""
def print_eval_step():
    """No-op placeholder; evaluation output is produced elsewhere."""
def check_config(c):
@ -530,6 +523,7 @@ def check_config(c):
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# dataloading
# pylint: disable=import-outside-toplevel
from TTS.utils.text import cleaners
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)

View File

@ -40,7 +40,7 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
return fig
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=[8, 24]):
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
if decoder_output is not None:
num_plot = 4
else: