mirror of https://github.com/coqui-ai/TTS.git
undo loc attn after fwd attn
parent f3dac0aa84
commit ab42396fbf
@@ -248,14 +248,15 @@ class Attention(nn.Module):
                 dim=1, keepdim=True)
         else:
             raise ValueError("Unknown value for attention norm type")
 
+        if self.location_attention:
+            self.update_location_attention(alignment)
+
         # apply forward attention if enabled
         if self.forward_attn:
             alignment = self.apply_forward_attention(alignment)
             self.alpha = alignment
 
-        if self.location_attention:
-            self.update_location_attention(alignment)
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         self.attention_weights = alignment