mirror of https://github.com/coqui-ai/TTS.git
Add location sensitive attention
This commit is contained in:
parent e6112f7b2d
commit 3348d462d2
@@ -29,34 +29,50 @@ class BahdanauAttention(nn.Module):
         # (batch, max_time)
         return alignment.squeeze(-1)
 
 
 def get_mask_from_lengths(inputs, inputs_lengths):
     """Get mask tensor from list of length
 
     Args:
         inputs: Tensor in size (batch, max_time, dim)
         inputs_lengths: array like
     """
     mask = inputs.data.new(inputs.size(0), inputs.size(1)).byte().zero_()
     for idx, l in enumerate(inputs_lengths):
         mask[idx][:l] = 1
     return ~mask
 
 
+class LocationSensitiveAttention(nn.Module):
+    """Location sensitive attention following
+    https://arxiv.org/pdf/1506.07503.pdf"""
+    def __init__(self, annot_dim, out_dim, hidden_dim):
+        loc_kernel_size = 31
+        loc_dim = 32
+        padding = int((loc_kernel_size -1) / 2)
+        self.loc_conv = nn.Conv1d(2, loc_dim,
+                                  kernel_size=loc_kernel_size, stride=1,
+                                  padding=padding, bias=False)
+        self.loc_linear = nn.Linear(loc_dim, hidden_dim)
+        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
+        self.v = nn.Linear(hidden_dim, 1, bias=False)
+
+    def forward(self, query, annot, loc):
+        """
+        Shapes:
+            - query: (batch, 1, dim) or (batch, dim)
+            - annots: (batch, max_time, dim)
+            - loc: (batch, 2, max_time)
+        """
+        loc_conv = self.loc_conv(loc)
+        loc_conv = loc_conv.transpose(1, 2)
+        processed_loc = self.loc_linear(loc_conv)
+        processed_query = self.query_layer(query)
+        processed_annots = self.annot_layer(annot)
+        alignment = self.v(nn.functional.tanh(
+            processed_query + processed_annots + processed_loc))
+        # (batch, max_time)
+        return alignment.squeeze(-1)
 
 
 class AttentionRNN(nn.Module):
-    def __init__(self, out_dim, annot_dim, memory_dim,
-                 score_mask_value=-float("inf")):
+    def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim, 3, out_dim)
-        self.score_mask_value = score_mask_value
 
     def forward(self, memory, context, rnn_state, annotations,
-                mask=None, annotations_lengths=None):
-        if annotations_lengths is not None and mask is None:
-            mask = get_mask_from_lengths(annotations, annotations_lengths)
+                attention_vec, mask=None, annotations_lengths=None):
 
         # Concat input query and previous context context
         rnn_input = torch.cat((memory, context), -1)
@@ -85,3 +101,4 @@ class AttentionRNN(nn.Module):
         context = torch.bmm(alignment.unsqueeze(1), annotations)
         context = context.squeeze(1)
         return rnn_output, context, alignment
+
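The get_mask_from_lengths helper returns a tensor that is 1 at padded time steps, so it can be used to mask attention scores over the annotations. A minimal sketch of the same idea with current PyTorch, using torch.arange and a bool mask instead of a ByteTensor and ~ (names here are illustrative, not taken from the commit):

import torch

def padding_mask(lengths, max_time):
    # lengths: 1D tensor of valid lengths per batch item
    steps = torch.arange(max_time).unsqueeze(0)   # (1, max_time)
    return steps >= lengths.unsqueeze(1)          # (batch, max_time), True on padding

lengths = torch.tensor([3, 5])
print(padding_mask(lengths, max_time=5))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])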
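LocationSensitiveAttention follows the convolutional attention of the linked paper, roughly e_{i,j} = w^T tanh(W s_{i-1} + V h_j + U f_{i,j}), where f_i convolves the previous alignment(s). A self-contained sketch of that scoring path with dummy tensors, matching the shapes in the docstring; the layer sizes and the interpretation of the two loc channels are assumptions, not taken from the commit:

import torch
import torch.nn as nn

batch, max_time, dim, hidden_dim, loc_dim = 2, 40, 128, 128, 32

loc_conv = nn.Conv1d(2, loc_dim, kernel_size=31, stride=1, padding=15, bias=False)
loc_linear = nn.Linear(loc_dim, hidden_dim)
query_layer = nn.Linear(dim, hidden_dim, bias=True)
annot_layer = nn.Linear(dim, hidden_dim, bias=True)
v = nn.Linear(hidden_dim, 1, bias=False)

query = torch.randn(batch, 1, dim)         # decoder state s_{i-1}
annot = torch.randn(batch, max_time, dim)  # encoder outputs h_j
loc = torch.randn(batch, 2, max_time)      # location features (e.g. previous + cumulative alignment)

processed_loc = loc_linear(loc_conv(loc).transpose(1, 2))  # (batch, max_time, hidden_dim)
processed_query = query_layer(query)                       # (batch, 1, hidden_dim), broadcasts over time
processed_annot = annot_layer(annot)                       # (batch, max_time, hidden_dim)
scores = v(torch.tanh(processed_query + processed_annot + processed_loc)).squeeze(-1)
alignment = torch.softmax(scores, dim=-1)                  # (batch, max_time)
print(alignment.shape)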
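AttentionRNN.forward now also takes an attention_vec argument. The commit does not show how the (batch, 2, max_time) location input for the Conv1d is assembled; one common choice, and purely an assumption here, is to stack the previous and the cumulative alignment after each decoder step, roughly:

import torch

batch, max_time = 2, 40
prev_alignment = torch.zeros(batch, max_time)
cum_alignment = torch.zeros(batch, max_time)

# after a decoder step, with `alignment` of shape (batch, max_time):
alignment = torch.softmax(torch.randn(batch, max_time), dim=-1)  # stand-in for the step output
cum_alignment = cum_alignment + alignment
prev_alignment = alignment
loc = torch.stack([prev_alignment, cum_alignment], dim=1)  # (batch, 2, max_time)
print(loc.shape)  # torch.Size([2, 2, 40])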