linter fixes

This commit is contained in:
erogol 2020-03-10 11:30:13 +01:00
parent e97bb45aba
commit 669a2e1d73
3 changed files with 30 additions and 34 deletions

View File

@ -179,6 +179,7 @@ class TacotronLoss(torch.nn.Module):
if c.ga_alpha > 0:
self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
# stopnet loss
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
def forward(self, postnet_output, decoder_output, mel_input, linear_input,

View File

@ -232,7 +232,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
'avg_postnet_loss': float(loss_dict['postnet_loss'].item()),
'avg_decoder_loss': float(loss_dict['decoder_loss'].item()),
'avg_stop_loss': loss_dict['stopnet_loss'].item()
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
if isinstance(loss_dict['stopnet_loss'], float) else float(loss_dict['stopnet_loss'].item()),
'avg_step_time': step_time,
'avg_loader_time': loader_time
}
@ -399,16 +399,16 @@ def evaluate(model, criterion, ap, global_step, epoch):
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
"StopLoss: {:.5f} - {:.5f} GALoss: {:.5f} - {:.5f} AlignScore: {:.4f} - {:.4f}"
.format(loss_dict['loss'].item(),
loss_dict['postnet_loss'].item(),
keep_avg['avg_postnet_loss'],
loss_dict['decoder_loss'].item(),
keep_avg['avg_decoder_loss'],
loss_dict['stopnet_loss'].item(),
keep_avg['avg_stop_loss'],
loss_dict['ga_loss'].item(),
keep_avg['avg_ga_loss'],
align_score, keep_avg['avg_align_score']),
flush=True)
loss_dict['postnet_loss'].item(),
keep_avg['avg_postnet_loss'],
loss_dict['decoder_loss'].item(),
keep_avg['avg_decoder_loss'],
loss_dict['stopnet_loss'].item(),
keep_avg['avg_stop_loss'],
loss_dict['ga_loss'].item(),
keep_avg['avg_ga_loss'],
align_score, keep_avg['avg_align_score']),
flush=True)
if args.rank == 0:
# Diagnostic visualizations

View File

@ -50,7 +50,7 @@ class AudioProcessor(object):
self.clip_norm = clip_norm
self.do_trim_silence = do_trim_silence
self.trim_db = trim_db
self.sound_norm = sound_norm
self.do_sound_norm = sound_norm
# setup stft parameters
if hop_length is None:
self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
@ -232,15 +232,24 @@ class AudioProcessor(object):
return librosa.effects.trim(
wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0]
def sound_norm(self, x):
# Rescale x so that its largest absolute value becomes 0.9.
# Declared static: the body uses only its argument, no instance state.
# NOTE(review): if every element of x is 0 the divisor is 0 - confirm
# callers never pass an all-zero signal.
@staticmethod
def sound_norm(x):
return x / abs(x).max() * 0.9
### save and load ###
### save and load ###
# Read an audio file and return its samples, optionally silence-trimmed
# and peak-normalized depending on the instance flags.
#   filename: path passed straight to the reader.
#   sr: requested sample rate; None means use the rate the reader reports.
# Returns x after the optional trim / normalization steps.
# Raises AssertionError when the resulting rate differs from self.sample_rate.
def load_wav(self, filename, sr=None):
if sr is None:
# No rate requested: take both samples and rate from sf.read.
x, sr = sf.read(filename)
else:
# Rate requested: hand it to librosa.load as the sr argument.
x, sr = librosa.load(filename, sr=sr)
if self.do_trim_silence:
# Trimming is best effort: on ValueError the untrimmed x is kept and
# a warning naming the file is printed.
try:
x = self.trim_silence(x)
except ValueError:
print(f' [!] File cannot be trimmed for silence - {filename}')
# The rate obtained above must match the configured one.
# NOTE(review): assert statements are skipped under `python -O`.
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
if self.do_sound_norm:
# Normalization is delegated to the sound_norm helper.
x = self.sound_norm(x)
return x
def save_wav(self, wav, path):
@ -263,20 +272,6 @@ class AudioProcessor(object):
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
# Read an audio file and return its samples, optionally silence-trimmed
# and peak-normalized. This copy of load_wav does the normalization inline.
#   filename: path passed straight to the reader.
#   sr: requested sample rate; None means use the rate the reader reports.
# Raises AssertionError when the resulting rate differs from self.sample_rate.
def load_wav(self, filename, sr=None):
if sr is None:
# No rate requested: take both samples and rate from sf.read.
x, sr = sf.read(filename)
else:
# Rate requested: hand it to librosa.load as the sr argument.
x, sr = librosa.load(filename, sr=sr)
if self.do_trim_silence:
# Best effort: on ValueError the untrimmed x is kept and a warning printed.
try:
x = self.trim_silence(x)
except ValueError:
print(f' [!] File cannot be trimmed for silence - {filename}')
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
# Here self.sound_norm is read as a truth-valued flag, not called.
if self.sound_norm:
# Scale so the largest absolute sample becomes 0.9.
# NOTE(review): an all-zero x gives a zero divisor - confirm inputs.
x = x / abs(x).max() * 0.9
return x
@staticmethod
def encode_16bits(x):