mirror of https://github.com/coqui-ai/TTS.git
simplified code for fwd attn
parent a6118564d5
commit 40f56f9b00
@@ -201,17 +201,17 @@ class Attention(nn.Module):
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         return attention

-    def apply_forward_attention(self, inputs, alignment, query):
+    def apply_forward_attention(self, alignment):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone().to(inputs.device),
+        fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device),
                                   (1, 0, 0, 0))
         # compute transition potentials
         alpha = ((1 - self.u) * self.alpha
-                 + self.u * prev_alpha
+                 + self.u * fwd_shifted_alpha
                  + 1e-8) * alignment
         # force incremental alignment
         if not self.training and self.forward_attn_mask:
-            _, n = prev_alpha.max(1)
+            _, n = fwd_shifted_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
@@ -221,16 +221,9 @@ class Attention(nn.Module):
                 alpha[b,
                       (n[b] - 2
                        )] = 0.01 * val[b]  # smoothing factor for the prev step
-        # compute attention weights
-        self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
-        # compute context
-        context = torch.bmm(self.alpha.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        # compute transition agent
-        if self.trans_agent:
-            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
-            self.u = torch.sigmoid(self.ta(ta_input))
-        return context, self.alpha
+        # renormalize attention weights
+        alpha = alpha / alpha.sum(dim=1, keepdim=True)
+        return alpha

     def forward(self, query, inputs, processed_inputs, mask):
         if self.location_attention:
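After this change, apply_forward_attention only updates and renormalizes the forward attention weights; computing the context vector and updating the transition agent move into forward() (next hunk). Below is a minimal standalone sketch of the recursion the slimmed-down method implements; the function name forward_attention_step and the toy shapes are illustrative only, not part of the repository.

import torch
import torch.nn.functional as F

def forward_attention_step(prev_alpha, alignment, u):
    # prev_alpha: [B, T] forward attention weights from the previous decoder step
    # alignment:  [B, T] content/location attention for the current decoder step
    # u:          [B, 1] transition agent probability (0 = stay, 1 = move forward)
    # shift the previous weights one encoder step to the right: alpha_{t-1}(n-1)
    fwd_shifted_alpha = F.pad(prev_alpha[:, :-1], (1, 0, 0, 0))
    # mix "stay" and "move" transitions and gate them with the new alignment
    alpha = ((1 - u) * prev_alpha + u * fwd_shifted_alpha + 1e-8) * alignment
    # renormalize over the encoder axis
    return alpha / alpha.sum(dim=1, keepdim=True)

# toy usage: attention starts on the first encoder step and can only stay or move right
B, T = 2, 6
prev_alpha = torch.zeros(B, T)
prev_alpha[:, 0] = 1.0
alignment = torch.softmax(torch.randn(B, T), dim=1)
u = torch.full((B, 1), 0.5)
alpha = forward_attention_step(prev_alpha, alignment, u)
print(alpha.sum(dim=1))  # each row sums to 1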
@@ -254,15 +247,20 @@ class Attention(nn.Module):
                                      attention).sum(
                                          dim=1, keepdim=True)
         else:
-            raise RuntimeError("Unknown value for attention norm type")
+            raise ValueError("Unknown value for attention norm type")
         if self.location_attention:
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(
-                inputs, alignment, query)
-        else:
-            context = torch.bmm(alignment.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            self.attention_weights = alignment
+            alignment = self.apply_forward_attention(alignment)
+            self.alpha = alignment
+
+        context = torch.bmm(alignment.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        self.attention_weights = alignment
+
+        # compute transition agent
+        if self.forward_attn and self.trans_agent:
+            ta_input = torch.cat([context, query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
         return context
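For reference, a small shape walk-through of the context computation that now runs for both the forward-attention and plain-attention paths in forward(); the sizes below are illustrative only.

import torch

B, T, D = 2, 6, 4                                      # batch, encoder steps, encoder dim
alignment = torch.softmax(torch.randn(B, T), dim=1)    # [B, T] attention weights
inputs = torch.randn(B, T, D)                          # [B, T, D] encoder outputs
# [B, 1, T] @ [B, T, D] -> [B, 1, D], then squeeze -> [B, D]
context = torch.bmm(alignment.unsqueeze(1), inputs).squeeze(1)
assert context.shape == (B, D)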