mirror of https://github.com/coqui-ai/TTS.git
pylint fix
This commit is contained in:
parent
3673cc1e30
commit
f63bce89f6
|
@ -224,7 +224,7 @@ class MyDataset(Dataset):
|
||||||
mel_lengths = torch.LongTensor(mel_lengths)
|
mel_lengths = torch.LongTensor(mel_lengths)
|
||||||
stop_targets = torch.FloatTensor(stop_targets)
|
stop_targets = torch.FloatTensor(stop_targets)
|
||||||
|
|
||||||
# compute linear spectrogram
|
# compute linear spectrogram
|
||||||
if self.compute_linear_spec:
|
if self.compute_linear_spec:
|
||||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||||
linear = prepare_tensor(linear, self.outputs_per_step)
|
linear = prepare_tensor(linear, self.outputs_per_step)
|
||||||
|
|
|
@ -13,6 +13,7 @@ os.makedirs(OUT_PATH, exist_ok=True)
|
||||||
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
|
conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
class TestAudio(unittest.TestCase):
|
class TestAudio(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(TestAudio, self).__init__(*args, **kwargs)
|
super(TestAudio, self).__init__(*args, **kwargs)
|
||||||
|
@ -165,7 +166,7 @@ class TestAudio(unittest.TestCase):
|
||||||
|
|
||||||
self.ap.signal_norm = False
|
self.ap.signal_norm = False
|
||||||
self.ap.preemphasis = 0.0
|
self.ap.preemphasis = 0.0
|
||||||
|
|
||||||
# test scaler forward and backward transforms
|
# test scaler forward and backward transforms
|
||||||
wav = self.ap.load_wav(WAV_FILE)
|
wav = self.ap.load_wav(WAV_FILE)
|
||||||
mel_reference = self.ap.melspectrogram(wav)
|
mel_reference = self.ap.melspectrogram(wav)
|
||||||
|
|
9
train.py
9
train.py
|
@ -7,7 +7,6 @@ import traceback
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from TTS.datasets.TTSDataset import MyDataset
|
from TTS.datasets.TTSDataset import MyDataset
|
||||||
|
@ -20,7 +19,7 @@ from TTS.utils.generic_utils import (
|
||||||
get_git_branch, load_config, remove_experiment_folder, save_best_model,
|
get_git_branch, load_config, remove_experiment_folder, save_best_model,
|
||||||
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
|
save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
|
||||||
setup_model, gradual_training_scheduler, KeepAverage,
|
setup_model, gradual_training_scheduler, KeepAverage,
|
||||||
set_weight_decay, check_config, print_train_step)
|
set_weight_decay, check_config)
|
||||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||||
from TTS.utils.console_logger import ConsoleLogger
|
from TTS.utils.console_logger import ConsoleLogger
|
||||||
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
|
||||||
|
@ -215,7 +214,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
||||||
update_train_values = {
|
update_train_values = {
|
||||||
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
|
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
|
||||||
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
|
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
|
||||||
'avg_stopnet_loss': loss_dict['stopnet_loss'].item()
|
'avg_stopnet_loss': loss_dict['stopnet_loss'].item() \
|
||||||
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
|
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
|
||||||
'avg_step_time': step_time,
|
'avg_step_time': step_time,
|
||||||
'avg_loader_time': loader_time
|
'avg_loader_time': loader_time
|
||||||
|
@ -591,8 +590,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
print("\n > Number of output frames:", model.decoder.r)
|
print("\n > Number of output frames:", model.decoder.r)
|
||||||
|
|
||||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
||||||
optimizer_st, scheduler, ap,
|
optimizer_st, scheduler, ap,
|
||||||
global_step, epoch)
|
global_step, epoch)
|
||||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
||||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||||
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
||||||
|
|
|
@ -72,7 +72,7 @@ class AudioProcessor(object):
|
||||||
# setup scaler
|
# setup scaler
|
||||||
if stats_path:
|
if stats_path:
|
||||||
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path)
|
||||||
self.setup_scaler(mel_mean, mel_std, linear_mean,linear_std)
|
self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std)
|
||||||
self.signal_norm = True
|
self.signal_norm = True
|
||||||
self.max_norm = None
|
self.max_norm = None
|
||||||
self.clip_norm = None
|
self.clip_norm = None
|
||||||
|
@ -107,7 +107,7 @@ class AudioProcessor(object):
|
||||||
# mean-var scaling
|
# mean-var scaling
|
||||||
if hasattr(self, 'mel_scaler'):
|
if hasattr(self, 'mel_scaler'):
|
||||||
if S.shape[0] == self.num_mels:
|
if S.shape[0] == self.num_mels:
|
||||||
return self.mel_scaler.transform(S.T).T
|
return self.mel_scaler.transform(S.T).T
|
||||||
elif S.shape[0] == self.n_fft / 2:
|
elif S.shape[0] == self.n_fft / 2:
|
||||||
return self.linear_scaler.transform(S.T).T
|
return self.linear_scaler.transform(S.T).T
|
||||||
else:
|
else:
|
||||||
|
@ -136,7 +136,7 @@ class AudioProcessor(object):
|
||||||
# mean-var scaling
|
# mean-var scaling
|
||||||
if hasattr(self, 'mel_scaler'):
|
if hasattr(self, 'mel_scaler'):
|
||||||
if S_denorm.shape[0] == self.num_mels:
|
if S_denorm.shape[0] == self.num_mels:
|
||||||
return self.mel_scaler.inverse_transform(S_denorm.T).T
|
return self.mel_scaler.inverse_transform(S_denorm.T).T
|
||||||
elif S_denorm.shape[0] == self.n_fft / 2:
|
elif S_denorm.shape[0] == self.n_fft / 2:
|
||||||
return self.linear_scaler.inverse_transform(S_denorm.T).T
|
return self.linear_scaler.inverse_transform(S_denorm.T).T
|
||||||
else:
|
else:
|
||||||
|
@ -168,10 +168,10 @@ class AudioProcessor(object):
|
||||||
for key in stats_config.keys():
|
for key in stats_config.keys():
|
||||||
if key in skip_parameters:
|
if key in skip_parameters:
|
||||||
continue
|
continue
|
||||||
assert stats_config[key] == self.__dict__[
|
assert stats_config[key] == self.__dict__[key],\
|
||||||
key], f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||||
return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
||||||
|
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
|
def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
|
||||||
self.mel_scaler = StandardScaler()
|
self.mel_scaler = StandardScaler()
|
||||||
|
@ -180,9 +180,11 @@ class AudioProcessor(object):
|
||||||
self.linear_scaler.set_stats(linear_mean, linear_std)
|
self.linear_scaler.set_stats(linear_mean, linear_std)
|
||||||
|
|
||||||
### DB and AMP conversion ###
|
### DB and AMP conversion ###
|
||||||
|
# pylint: disable=no-self-use
|
||||||
def _amp_to_db(self, x):
|
def _amp_to_db(self, x):
|
||||||
return 20 * np.log10(np.maximum(1e-5, x))
|
return 20 * np.log10(np.maximum(1e-5, x))
|
||||||
|
|
||||||
|
# pylint: disable=no-self-use
|
||||||
def _db_to_amp(self, x):
|
def _db_to_amp(self, x):
|
||||||
return np.power(10.0, x * 0.05)
|
return np.power(10.0, x * 0.05)
|
||||||
|
|
||||||
|
@ -269,15 +271,14 @@ class AudioProcessor(object):
|
||||||
y = self._istft(S_complex * angles)
|
y = self._istft(S_complex * angles)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def compute_stft_paddings(self,x, pad_sides=1):
|
def compute_stft_paddings(self, x, pad_sides=1):
|
||||||
'''compute right padding (final frame) or both sides padding (first and final frames)
|
'''compute right padding (final frame) or both sides padding (first and final frames)
|
||||||
'''
|
'''
|
||||||
assert pad_sides in (1, 2)
|
assert pad_sides in (1, 2)
|
||||||
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
|
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
|
||||||
if pad_sides == 1:
|
if pad_sides == 1:
|
||||||
return 0, pad
|
return 0, pad
|
||||||
else:
|
return pad // 2, pad // 2 + pad % 2
|
||||||
return pad // 2, pad // 2 + pad % 2
|
|
||||||
|
|
||||||
### Audio Processing ###
|
### Audio Processing ###
|
||||||
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
|
||||||
|
|
|
@ -22,6 +22,7 @@ class ConsoleLogger():
|
||||||
self.old_epoch_loss_dict = None
|
self.old_epoch_loss_dict = None
|
||||||
self.old_eval_loss_dict = None
|
self.old_eval_loss_dict = None
|
||||||
|
|
||||||
|
# pylint: disable=no-self-use
|
||||||
def get_time(self):
|
def get_time(self):
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
return now.strftime("%Y-%m-%d %H:%M:%S")
|
return now.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
@ -47,10 +48,11 @@ class ConsoleLogger():
|
||||||
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
|
||||||
else:
|
else:
|
||||||
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
|
||||||
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\
|
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\n{indent}\
|
||||||
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
|
step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
|
||||||
print(log_text, flush=True)
|
print(log_text, flush=True)
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
def print_train_epoch_end(self, global_step, epoch, epoch_time,
|
def print_train_epoch_end(self, global_step, epoch, epoch_time,
|
||||||
print_dict):
|
print_dict):
|
||||||
indent = " | > "
|
indent = " | > "
|
||||||
|
|
|
@ -369,7 +369,7 @@ class KeepAverage():
|
||||||
return self.avg_values[key]
|
return self.avg_values[key]
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self.avg_values.items()
|
return self.avg_values.items()
|
||||||
|
|
||||||
def add_value(self, name, init_val=0, init_iter=0):
|
def add_value(self, name, init_val=0, init_iter=0):
|
||||||
self.avg_values[name] = init_val
|
self.avg_values[name] = init_val
|
||||||
|
@ -412,7 +412,7 @@ def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restric
|
||||||
|
|
||||||
tcolors = AttrDict({
|
tcolors = AttrDict({
|
||||||
'OKBLUE': '\033[94m',
|
'OKBLUE': '\033[94m',
|
||||||
'HEADER': '\033[95m',
|
'HEADER': '\033[95m',
|
||||||
'OKGREEN': '\033[92m',
|
'OKGREEN': '\033[92m',
|
||||||
'WARNING': '\033[93m',
|
'WARNING': '\033[93m',
|
||||||
'FAIL': '\033[91m',
|
'FAIL': '\033[91m',
|
||||||
|
@ -428,17 +428,10 @@ def print_train_step(batch_steps, step, global_step, avg_spec_length, avg_text_l
|
||||||
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
|
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
|
||||||
for key, value in print_dict.items():
|
for key, value in print_dict.items():
|
||||||
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
|
log_text += "{}{}: {:.5f}\n".format(indent, key, value)
|
||||||
log_text += "{}avg_spec_len: {}\n{}avg_text_len: {}\n{}step_time: {:.2f}\n{}loader_time: {:.2f}\n{}lr: {:.5f}"\
|
log_text += f"{indent}avg_spec_len: {avg_spec_length}\n{indent}avg_text_len: {avg_text_length}\
|
||||||
|
\n{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"\
|
||||||
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
|
.format(indent, avg_spec_length, indent, avg_text_length, indent, step_time, indent, loader_time, indent, lr)
|
||||||
print(log_text, flush=True)
|
print(log_text, flush=True)
|
||||||
|
|
||||||
|
|
||||||
def print_train_epoch(step, global_step, epoch, loss_dict):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def print_eval_step():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def check_config(c):
|
def check_config(c):
|
||||||
|
@ -530,6 +523,7 @@ def check_config(c):
|
||||||
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
_check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
|
||||||
|
|
||||||
# dataloading
|
# dataloading
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
from TTS.utils.text import cleaners
|
from TTS.utils.text import cleaners
|
||||||
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
_check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
|
||||||
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
_check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
|
||||||
|
|
|
@ -40,7 +40,7 @@ def plot_spectrogram(linear_output, audio, fig_size=(16, 10)):
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=[8, 24]):
|
def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, decoder_output=None, output_path=None, figsize=(8, 24)):
|
||||||
if decoder_output is not None:
|
if decoder_output is not None:
|
||||||
num_plot = 4
|
num_plot = 4
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue