update GMM attention clamp max min

Eren Golge 2019-11-11 11:30:23 +01:00
parent 6f3dd1b6ae
commit 1401a0db6b
2 changed files with 265 additions and 264 deletions

View File

@@ -157,7 +157,7 @@ class GravesAttention(nn.Module):
         k_t = gbk_t[:, 2, :]
         # attention GMM parameters
-        inv_sig_t = torch.exp(-torch.clamp(b_t, min=-7, max=9))  # variance
+        inv_sig_t = torch.exp(-torch.clamp(b_t, min=-6, max=9))  # variance
         mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
         g_t = torch.softmax(g_t, dim=-1) * inv_sig_t + self.eps
@@ -173,7 +173,7 @@ class GravesAttention(nn.Module):
         phi_t = g_t * torch.exp(-0.5 * inv_sig_t * (mu_t_ - j)**2)
         alpha_t = self.COEF * torch.sum(phi_t, 1)
-        if alpha_t.max() > 1e+2:
+        if alpha_t.max() > 1e+3:
             breakpoint()
         # apply masking
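
For reference (not part of the commit), a small standalone check of what the tightened clamp does to the inverse variance: raising the lower bound from -7 to -6 caps inv_sig_t at roughly e**6 ≈ 403 instead of e**7 ≈ 1097, while the unchanged max of 9 still floors it at about 1.2e-4.

    import torch

    # Illustration only: bounds implied by the old vs. new clamp on b_t.
    b_t = torch.tensor([-100.0, 0.0, 100.0])
    old = torch.exp(-torch.clamp(b_t, min=-7, max=9))
    new = torch.exp(-torch.clamp(b_t, min=-6, max=9))
    print(old)  # ~[1096.6, 1.0, 1.2e-4] -> inv_sig_t capped at e**7
    print(new)  # ~[403.4,  1.0, 1.2e-4] -> inv_sig_t capped at e**6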

View File

@@ -48,6 +48,7 @@ class AudioProcessor(object):
         self.do_trim_silence = do_trim_silence
         self.sound_norm = sound_norm
         self.n_fft, self.hop_length, self.win_length = self._stft_parameters()
+        assert min_level_db != 0.0, " [!] min_level_db is 0"
         members = vars(self)
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))