pylint fix

This commit is contained in:
erogol 2020-04-23 15:46:45 +02:00
parent 3673cc1e30
commit f63bce89f6
7 changed files with 27 additions and 30 deletions

View File

@@ -224,7 +224,7 @@ class MyDataset(Dataset):
mel_lengths = torch.LongTensor(mel_lengths) mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets) stop_targets = torch.FloatTensor(stop_targets)
# compute linear spectrogram # compute linear spectrogram
if self.compute_linear_spec: if self.compute_linear_spec:
linear = [self.ap.spectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
linear = prepare_tensor(linear, self.outputs_per_step) linear = prepare_tensor(linear, self.outputs_per_step)

View File

@@ -13,6 +13,7 @@ os.makedirs(OUT_PATH, exist_ok=True)
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
# pylint: disable=protected-access
class TestAudio(unittest.TestCase): class TestAudio(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(TestAudio, self).__init__(*args, **kwargs) super(TestAudio, self).__init__(*args, **kwargs)
@@ -165,7 +166,7 @@ class TestAudio(unittest.TestCase):
self.ap.signal_norm = False self.ap.signal_norm = False
self.ap.preemphasis = 0.0 self.ap.preemphasis = 0.0
# test scaler forward and backward transforms # test scaler forward and backward transforms
wav = self.ap.load_wav(WAV_FILE) wav = self.ap.load_wav(WAV_FILE)
mel_reference = self.ap.melspectrogram(wav) mel_reference = self.ap.melspectrogram(wav)

View File

@@ -7,7 +7,6 @@ import traceback
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset from TTS.datasets.TTSDataset import MyDataset
@@ -20,7 +19,7 @@ from TTS.utils.generic_utils import (
get_git_branch, load_config, remove_experiment_folder, save_best_model, get_git_branch, load_config, remove_experiment_folder, save_best_model,
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
setup_model, gradual_training_scheduler, KeepAverage, setup_model, gradual_training_scheduler, KeepAverage,
set_weight_decay, check_config, print_train_step) set_weight_decay, check_config)
from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.console_logger import ConsoleLogger from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@@ -215,7 +214,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
update_train_values = { update_train_values = {
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()), 'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()), 'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
'avg_stopnet_loss': loss_dict['stopnet_loss'].item() 'avg_stopnet_loss': loss_dict['stopnet_loss'].item() \
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()), if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
'avg_step_time': step_time, 'avg_step_time': step_time,
'avg_loader_time': loader_time 'avg_loader_time': loader_time
@@ -591,8 +590,8 @@ def main(args):  # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r) print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer, train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap, optimizer_st, scheduler, ap,
global_step, epoch) global_step, epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict) c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss'] target_loss = train_avg_loss_dict['avg_postnet_loss']

View File

@@ -72,7 +72,7 @@ class AudioProcessor(object):
# setup scaler # setup scaler
if stats_path: if stats_path:
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
self.setup_scaler(mel_mean, mel_std, linear_mean,linear_std) self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
self.signal_norm = True self.signal_norm = True
self.max_norm = None self.max_norm = None
self.clip_norm = None self.clip_norm = None
@@ -107,7 +107,7 @@ class AudioProcessor(object):
# mean-var scaling # mean-var scaling
if hasattr(self, 'mel_scaler'): if hasattr(self, 'mel_scaler'):
if S.shape[0] == self.num_mels: if S.shape[0] == self.num_mels:
return self.mel_scaler.transform(S.T).T return self.mel_scaler.transform(S.T).T
elif S.shape[0] == self.n_fft / 2: elif S.shape[0] == self.n_fft / 2:
return self.linear_scaler.transform(S.T).T return self.linear_scaler.transform(S.T).T
else: else:
@@ -136,7 +136,7 @@ class AudioProcessor(object):
# mean-var scaling # mean-var scaling
if hasattr(self, 'mel_scaler'): if hasattr(self, 'mel_scaler'):
if S_denorm.shape[0] == self.num_mels: if S_denorm.shape[0] == self.num_mels:
return self.mel_scaler.inverse_transform(S_denorm.T).T return self.mel_scaler.inverse_transform(S_denorm.T).T
elif S_denorm.shape[0] == self.n_fft / 2: elif S_denorm.shape[0] == self.n_fft / 2:
return self.linear_scaler.inverse_transform(S_denorm.T).T return self.linear_scaler.inverse_transform(S_denorm.T).T
else: else:
@@ -168,10 +168,10 @@ class AudioProcessor(object):
for key in stats_config.keys(): for key in stats_config.keys():
if key in skip_parameters: if key in skip_parameters:
continue continue
assert stats_config[key] == self.__dict__[ assert stats_config[key] == self.__dict__[key],\
key], f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
return mel_mean, mel_std, linear_mean, linear_std, stats_config return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std): def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
self.mel_scaler = StandardScaler() self.mel_scaler = StandardScaler()
@@ -180,9 +180,11 @@ class AudioProcessor(object):
self.linear_scaler.set_stats(linear_mean, linear_std) self.linear_scaler.set_stats(linear_mean, linear_std)
### DB and AMP conversion ### ### DB and AMP conversion ###
# pylint: disable=no-self-use
def _amp_to_db(self, x): def _amp_to_db(self, x):
return 20 * np.log10(np.maximum(1e-5, x)) return 20 * np.log10(np.maximum(1e-5, x))
# pylint: disable=no-self-use
def _db_to_amp(self, x): def _db_to_amp(self, x):
return np.power(10.0, x * 0.05) return np.power(10.0, x * 0.05)
@@ -269,15 +271,14 @@ class AudioProcessor(object):
y = self._istft(S_complex * angles) y = self._istft(S_complex * angles)
return y return y
def compute_stft_paddings(self,x, pad_sides=1): def compute_stft_paddings(self, x, pad_sides=1):
'''compute right padding (final frame) or both sides padding (first and final frames) '''compute right padding (final frame) or both sides padding (first and final frames)
''' '''
assert pad_sides in (1, 2) assert pad_sides in (1, 2)
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
if pad_sides == 1: if pad_sides == 1:
return 0, pad return 0, pad
else: return pad // 2, pad // 2 + pad % 2
return pad // 2, pad // 2 + pad % 2
### Audio Processing ### ### Audio Processing ###
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8): def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):

View File

@@ -22,6 +22,7 @@ class ConsoleLogger():
self.old_epoch_loss_dict = None self.old_epoch_loss_dict = None
self.old_eval_loss_dict = None self.old_eval_loss_dict = None
# pylint: disable=no-self-use
def get_time(self): def get_time(self):
now = datetime.datetime.now() now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S") return now.strftime("%Y-%m-%d %H:%M:%S")
@@ -47,10 +48,11 @@ class ConsoleLogger():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
else: else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value) log_text += "{}{}: {:.5f} \n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\ log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr) step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
print(log_text, flush=True) print(log_text, flush=True)
# pylint: disable=unused-argument
def print_train_epoch_end(self, global_step, epoch, epoch_time, def print_train_epoch_end(self, global_step, epoch, epoch_time,
print_dict): print_dict):
indent = " | > " indent = " | > "

View File

@@ -369,7 +369,7 @@ class KeepAverage():
return self.avg_values[key] return self.avg_values[key]
def items(self): def items(self):
return self.avg_values.items() return self.avg_values.items()
def add_value(self, name, init_val=0, init_iter=0): def add_value(self, name, init_val=0, init_iter=0):
self.avg_values[name] = init_val self.avg_values[name] = init_val
@@ -412,7 +412,7 @@ def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restric
tcolors = AttrDict({ tcolors = AttrDict({
'OKBLUE': '\033[94m', 'OKBLUE': '\033[94m',
'HEADER': '\033[95m', 'HEADER': '\033[95m',
'OKGREEN': '\033[92m', 'OKGREEN': '\033[92m',
'WARNING': '\033[93m', 'WARNING': '\033[93m',
'FAIL': '\033[91m', 'FAIL': '\033[91m',
@@ -428,17 +428,10 @@ def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_l
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC) log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
for key, value in print_dict.items(): for key, value in print_dict.items():
log_text += "{}{}: {:.5f}\n".format(indent, key, value) log_text += "{}{}: {:.5f}\n".format(indent, key, value)
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\ log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\
\n{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"\
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr) .format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
print(log_text, flush=True) print(log_text, flush=True)
def print_train_epoch(step, global_step, epoch, loss_dict):
pass
def print_eval_step():
pass
def check_config(c): def check_config(c):
@@ -530,6 +523,7 @@ def check_config(c):
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
# dataloading # dataloading
# pylint: disable=import-outside-toplevel
from TTS.utils.text import cleaners from TTS.utils.text import cleaners
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)

View File

@@ -40,7 +40,7 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
return fig return fig
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=[8, 24]): def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
if decoder_output is not None: if decoder_output is not None:
num_plot = 4 num_plot = 4
else: else: