mirror of https://github.com/coqui-ai/TTS.git
simplified code for fwd attn
parent a6118564d5
commit 40f56f9b00
@@ -201,17 +201,17 @@ class Attention(nn.Module):
         self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention
 
-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
-                           (1, 0, 0, 0))
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
+                                  (1, 0, 0, 0))
         # compute transition potentials
         alpha = ((1 - self.u) * self.alpha
-                 + self.u * prev_alpha
+                 + self.u * fwd_shifted_alpha
                  + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
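Aside: the F.pad(self.alpha[:, :-1], (1, 0, 0, 0)) idiom above right-shifts each attention row by one encoder step, so position j receives the weight that position j-1 held at the previous decoder step. A minimal sketch of the transition-potential computation, with made-up shapes and a made-up transition-agent value u:

    import torch
    import torch.nn.functional as F

    alpha = torch.tensor([[0.1, 0.7, 0.2, 0.0]])      # [B=1, T=4] previous attention weights
    u = 0.5                                           # hypothetical transition-agent output

    # drop the last column, pad one zero on the left: a right-shift along time
    fwd_shifted_alpha = F.pad(alpha[:, :-1], (1, 0, 0, 0))
    print(fwd_shifted_alpha)                          # -> [[0.0, 0.1, 0.7, 0.2]]

    alignment = torch.tensor([[0.2, 0.3, 0.4, 0.1]])  # current-step alignment (made up)
    # transition potentials as in the diff: stay with weight (1 - u), advance with weight u
    new_alpha = ((1 - u) * alpha + u * fwd_shifted_alpha + 1e-8) * alignment
    new_alpha = new_alpha / new_alpha.sum(dim=1, keepdim=True)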
@@ -221,16 +221,9 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha
 
     def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
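Aside: the renormalization rewrite is behavior-preserving. alpha.sum(dim=1).unsqueeze(1) and alpha.sum(dim=1, keepdim=True) produce the same [B, 1] denominator; the new form simply avoids the extra call. A quick check with arbitrary values:

    import torch

    alpha = torch.rand(2, 5)                      # [B, T], arbitrary positive values
    old = alpha / alpha.sum(dim=1).unsqueeze(1)   # old form
    new = alpha / alpha.sum(dim=1, keepdim=True)  # new form
    assert torch.allclose(old, new)
    assert torch.allclose(new.sum(dim=1), torch.ones(2))  # rows now sum to 1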
@@ -254,15 +247,20 @@ class Attention(nn.Module):
                 attention).sum(
                     dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")
         if self.location_attention:
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, query)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
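Aside: after this change, apply_forward_attention only rescores the alignment; the context vector and the transition-agent update are computed once in forward, on the same code path whether or not forward_attn is enabled. A condensed, self-contained sketch of that flow (the shapes and the ta linear layer are stand-ins for illustration, not the library's real class):

    import torch
    import torch.nn as nn

    B, T, D_enc, D_query = 2, 6, 8, 16              # hypothetical sizes
    inputs = torch.rand(B, T, D_enc)                # encoder outputs
    alignment = torch.softmax(torch.rand(B, T), dim=1)
    query = torch.rand(B, 1, D_query)               # decoder query
    ta = nn.Linear(D_enc + D_query, 1)              # stand-in transition agent

    forward_attn, trans_agent = True, True
    if forward_attn:
        # stands in for self.apply_forward_attention(alignment)
        alignment = alignment / alignment.sum(dim=1, keepdim=True)

    # context is now computed once, outside the branch
    context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)   # [B, D_enc]

    if forward_attn and trans_agent:
        ta_input = torch.cat([context, query.squeeze(1)], dim=-1)    # [B, D_enc + D_query]
        u = torch.sigmoid(ta(ta_input))                              # transition prob, [B, 1]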