Enable attention windowing and make it configurable at the model level.

Eren Golge 2019-01-13 19:09:44 +01:00
parent 9927664f27
commit ed1f648b83
3 changed files with 26 additions and 5 deletions

View File

@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
class AttentionRNNCell(nn.Module):
- def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
+ def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
r"""
General Attention RNN wrapper
@@ -110,10 +110,17 @@ class AttentionRNNCell(nn.Module):
annot_dim (int): annotation vector feature dimension.
memory_dim (int): memory vector (decoder output) feature dimension.
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
+ windowing (bool): attention windowing forcing monotonic attention.
+ It is only active in eval mode.
"""
super(AttentionRNNCell, self).__init__()
self.align_model = align_model
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
+ self.windowing = windowing
+ if self.windowing:
+ self.win_back = 1
+ self.win_front = 3
+ self.win_idx = None
# pick bahdanau or location sensitive attention
if align_model == 'b':
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
@@ -138,6 +145,7 @@ class AttentionRNNCell(nn.Module):
"""
if t == 0:
self.alignment_model.reset()
+ self.win_idx = 0
# Feed it to RNN
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
@@ -151,6 +159,17 @@ class AttentionRNNCell(nn.Module):
if mask is not None:
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
+ # Windowing (gated on self.windowing so the flag actually controls it,
+ # and so eval mode does not touch self.win_back when windowing is off)
+ if not self.training and self.windowing:
+ back_win = self.win_idx - self.win_back
+ front_win = self.win_idx + self.win_front
+ if back_win > 0:
+ alignment[:, :back_win] = -float("inf")
+ if front_win < memory.shape[1]:
+ alignment[:, front_win:] = -float("inf")
+ # Update the window
+ self.win_idx = torch.argmax(alignment, 1).long()[0].item()
# Normalize context weight
# alignment = F.softmax(alignment, dim=-1)
# alignment = 5 * alignment
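
The windowing above acts on the raw alignment scores before they are normalized: everything outside a band of win_back steps behind and win_front steps ahead of the previously attended encoder index is pushed to -inf, which forces roughly monotonic attention at inference time. Below is a minimal, self-contained sketch of that masking step; the helper function and the toy tensor are illustrative, only the win_back/win_front defaults and the masking/argmax logic mirror the diff.

import torch

def apply_attention_window(alignment, win_idx, win_back=1, win_front=3):
    """Illustrative helper (not part of the commit): mask raw alignment
    scores outside [win_idx - win_back, win_idx + win_front) so attention
    can only move within a small band around the last attended step."""
    n_encoder_steps = alignment.shape[1]
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        alignment[:, :back_win] = -float("inf")
    if front_win < n_encoder_steps:
        alignment[:, front_win:] = -float("inf")
    # next window center = best-scoring position inside the current window
    new_win_idx = torch.argmax(alignment, 1)[0].item()
    return alignment, new_win_idx

# toy run: batch of 1, 6 encoder steps, window currently centered at index 2
scores = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7]])
masked, next_idx = apply_attention_window(scores.clone(), win_idx=2)
print(masked)    # indices 0 and 5 are now -inf
print(next_idx)  # 1, the argmax among the surviving indices 1..4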

View File

@@ -304,7 +304,7 @@ class Decoder(nn.Module):
r (int): number of outputs per time step.
"""
- def __init__(self, in_features, memory_dim, r):
+ def __init__(self, in_features, memory_dim, r, attn_windowing):
super(Decoder, self).__init__()
self.r = r
self.in_features = in_features
@@ -318,7 +318,8 @@ class Decoder(nn.Module):
rnn_dim=256,
annot_dim=in_features,
memory_dim=128,
- align_model='ls')
+ align_model='ls',
+ windowing=attn_windowing)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
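
The Decoder change is plain plumbing: the new attn_windowing argument is forwarded into the AttentionRNNCell as windowing. A hedged usage sketch, with the import path assumed and the values taken from what is visible in this commit (in_features=256, memory_dim=mel_dim=80, r=5):

from layers.tacotron import Decoder  # import path assumed

decoder = Decoder(in_features=256, memory_dim=80, r=5, attn_windowing=True)
decoder.eval()  # the window is only applied once the module leaves training mode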

View File

@@ -12,7 +12,8 @@ class Tacotron(nn.Module):
linear_dim=1025,
mel_dim=80,
r=5,
- padding_idx=None):
+ padding_idx=None,
+ attn_windowing=False):
super(Tacotron, self).__init__()
self.r = r
self.mel_dim = mel_dim
@@ -24,7 +25,7 @@ class Tacotron(nn.Module):
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(embedding_dim)
- self.decoder = Decoder(256, mel_dim, r)
+ self.decoder = Decoder(256, mel_dim, r, attn_windowing)
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
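
At the model level the feature is now a single constructor flag, attn_windowing, defaulting to False. A hedged usage sketch: the leading character-embedding argument of Tacotron is not visible in this diff, so num_chars below is a hypothetical stand-in, and the import path is assumed; the keyword arguments mirror the defaults shown above.

from models.tacotron import Tacotron  # import path assumed

num_chars = 61  # hypothetical stand-in for the leading constructor argument
model = Tacotron(num_chars,
                 linear_dim=1025,
                 mel_dim=80,
                 r=5,
                 padding_idx=None,
                 attn_windowing=True)

model.eval()  # windowing is gated on eval mode, so training alignments are unaffected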