mirror of https://github.com/coqui-ai/TTS.git
graves attention as in melnet paper
This commit is contained in:
parent
5e148038be
commit
e5bf2719bd
|
@ -131,8 +131,8 @@ class GravesAttention(nn.Module):
|
||||||
torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
|
torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
|
||||||
|
|
||||||
def init_states(self, inputs):
|
def init_states(self, inputs):
|
||||||
if self.J is None or inputs.shape[1] > self.J.shape[-1]:
|
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
|
||||||
self.J = torch.arange(0, inputs.shape[1]+1).to(inputs.device) + 0.5
|
self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5
|
||||||
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
|
||||||
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
|
||||||
|
|
||||||
|
@ -160,24 +160,25 @@ class GravesAttention(nn.Module):
|
||||||
|
|
||||||
# attention GMM parameters
|
# attention GMM parameters
|
||||||
sig_t = torch.nn.functional.softplus(b_t) + self.eps
|
sig_t = torch.nn.functional.softplus(b_t) + self.eps
|
||||||
|
|
||||||
mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
|
mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
|
||||||
g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
|
g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
|
||||||
|
|
||||||
j = self.J[:inputs.size(1)+1]
|
j = self.J[:inputs.size(1)+1]
|
||||||
|
|
||||||
# attention weights
|
# attention weights
|
||||||
phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2))
|
phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.exp((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
|
||||||
|
|
||||||
# discritize attention weights
|
# discritize attention weights
|
||||||
alpha_t = self.COEF * torch.sum(phi_t, 1)
|
alpha_t = torch.sum(phi_t, 1)
|
||||||
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
|
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
|
||||||
|
alpha_t[alpha_t == 0] = 1e-8
|
||||||
|
|
||||||
# apply masking
|
# apply masking
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
alpha_t.data.masked_fill_(~mask, self._mask_value)
|
alpha_t.data.masked_fill_(~mask, self._mask_value)
|
||||||
|
|
||||||
context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
|
context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1)
|
||||||
|
# for better visualization
|
||||||
|
# self.attention_weights = torch.clamp(alpha_t, min=0)
|
||||||
self.attention_weights = alpha_t
|
self.attention_weights = alpha_t
|
||||||
self.mu_prev = mu_t
|
self.mu_prev = mu_t
|
||||||
return context
|
return context
|
||||||
|
|
|
@ -1,11 +1,18 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
def alignment_diagonal_score(alignments):
|
|
||||||
|
def alignment_diagonal_score(alignments, binary=False):
|
||||||
"""
|
"""
|
||||||
Compute how diagonal alignment predictions are. It is useful
|
Compute how diagonal alignment predictions are. It is useful
|
||||||
to measure the alignment consistency of a model
|
to measure the alignment consistency of a model
|
||||||
Args:
|
Args:
|
||||||
alignments (torch.Tensor): batch of alignments.
|
alignments (torch.Tensor): batch of alignments.
|
||||||
|
binary (bool): if True, ignore scores and consider attention
|
||||||
|
as a binary mask.
|
||||||
Shape:
|
Shape:
|
||||||
alignments : batch x decoder_steps x encoder_steps
|
alignments : batch x decoder_steps x encoder_steps
|
||||||
"""
|
"""
|
||||||
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0).item()
|
maxs = alignments.max(dim=1)[0]
|
||||||
|
if binary:
|
||||||
|
maxs[maxs > 0] = 1
|
||||||
|
return maxs.mean(dim=1).mean(dim=0).item()
|
||||||
|
|
Loading…
Reference in New Issue