pylint fix

This commit is contained in:
erogol 2020-04-23 15:46:45 +02:00
parent 3673cc1e30
commit f63bce89f6
7 changed files with 27 additions and 30 deletions

View File

@ -13,6 +13,7 @@ os.makedirs(OUT_PATH, exist_ok=True)
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
# pylint: disable=protected-access
class TestAudio(unittest.TestCase): class TestAudio(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestAudio, self).__init__(*args, **kwargs) super(TestAudio, self).__init__(*args, **kwargs)

View File

@ -7,7 +7,6 @@ import traceback
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset from TTS.datasets.TTSDataset import MyDataset
@ -20,7 +19,7 @@ from TTS.utils.generic_utils import (
get_git_branch, load_config, remove_experiment_folder, save_best_model, get_git_branch, load_config, remove_experiment_folder, save_best_model,
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
setup_model, gradual_training_scheduler, KeepAverage, setup_model, gradual_training_scheduler, KeepAverage,
set_weight_decay, check_config, print_train_step) set_weight_decay, check_config)
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@ -215,7 +214,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
update_train_values = { update_train_values = {
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()), 'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()), 'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
'avg_stopnet_loss': loss_dict['stopnet_loss'].item() 'avg_stopnet_loss': loss_dict['stopnet_loss'].item() \
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()), if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
'avg_step_time': step_time, 'avg_step_time': step_time,
'avg_loader_time': loader_time 'avg_loader_time': loader_time

View File

@ -72,7 +72,7 @@ class AudioProcessor(object):
# setup scaler # setup scaler
if stats_path: if stats_path:
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
self.setup_scaler(mel_mean, mel_std, linear_mean,linear_std) self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
self.signal_norm = True self.signal_norm = True
self.max_norm = None self.max_norm = None
self.clip_norm = None self.clip_norm = None
@ -168,8 +168,8 @@ class AudioProcessor(object):
for key in stats_config.keys(): for key in stats_config.keys():
if key in skip_parameters: if key in skip_parameters:
continue continue
assert stats_config[key] == self.__dict__[ assert stats_config[key] == self.__dict__[key],\
key], f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
return mel_mean, mel_std, linear_mean, linear_std, stats_config return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
@ -180,9 +180,11 @@ class AudioProcessor(object):
self.linear_scaler.set_stats(linear_mean, linear_std) self.linear_scaler.set_stats(linear_mean, linear_std)
### DB and AMP conversion ### ### DB and AMP conversion ###
# pylint: disable=no-self-use
def _amp_to_db(self, x): def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x)) return 20 * np.log10(np.maximum(1e-5, x))
# pylint: disable=no-self-use
def _db_to_amp(self, x): def _db_to_amp(self, x):
return np.power(10.0, x * 0.05) return np.power(10.0, x * 0.05)
@ -269,14 +271,13 @@ class AudioProcessor(object):
y = self._istft(S_complex * angles) y = self._istft(S_complex * angles)
return y return y
def compute_stft_paddings(self,x, pad_sides=1): def compute_stft_paddings(self, x, pad_sides=1):
'''compute right padding (final frame) or both sides padding (first and final frames) '''compute right padding (final frame) or both sides padding (first and final frames)
''' '''
assert pad_sides in (1, 2) assert pad_sides in (1, 2)
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
if pad_sides == 1: if pad_sides == 1:
return 0, pad return 0, pad
else:
return pad // 2, pad // 2 + pad % 2 return pad // 2, pad // 2 + pad % 2
### Audio Processing ### ### Audio Processing ###

View File

@ -22,6 +22,7 @@ class ConsoleLogger():
self.old_epoch_loss_dict = None self.old_epoch_loss_dict = None
self.old_eval_loss_dict = None self.old_eval_loss_dict = None
# pylint: disable=no-self-use
def get_time(self): def get_time(self):
now = datetime.datetime.now() now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S") return now.strftime("%Y-%m-%d %H:%M:%S")
@ -47,10 +48,11 @@ class ConsoleLogger():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else: else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value) log_text += "{}{}: {:.5f} \n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\ log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr) step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
print(log_text, flush=True) print(log_text, flush=True)
# pylint: disable=unused-argument
def print_train_epoch_end(self, global_step, epoch, epoch_time, def print_train_epoch_end(self, global_step, epoch, epoch_time,
print_dict): print_dict):
indent = " | > " indent = " | > "

View File

@ -428,19 +428,12 @@ def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_l
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC) log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
for key, value in print_dict.items(): for key, value in print_dict.items():
log_text += "{}{}: {:.5f}\n".format(indent, key, value) log_text += "{}{}: {:.5f}\n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\ log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\
\n{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr) .format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
print(log_text, flush=True) print(log_text, flush=True)
def print_train_epoch(step, global_step, epoch, loss_dict):
pass
def print_eval_step():
pass
def check_config(c): def check_config(c):
_check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
_check_argument('run_name', c, restricted=True, val_type=str) _check_argument('run_name', c, restricted=True, val_type=str)
@ -530,6 +523,7 @@ def check_config(c):
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# dataloading # dataloading
# pylint: disable=import-outside-toplevel
from TTS.utils.text import cleaners from TTS.utils.text import cleaners
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)

View File

@ -40,7 +40,7 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
return fig return fig
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=[8, 24]): def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
if decoder_output is not None: if decoder_output is not None:
num_plot = 4 num_plot = 4
else: else: