diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 2fa6d06f..0826ccc6 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -251,6 +251,7 @@ class Attention(nn.Module): else: context = torch.bmm(alignment.unsqueeze(1), inputs) context = context.squeeze(1) + self.attention_weights = alignment return context