mirror of https://github.com/coqui-ai/TTS.git
Add location sens attention
parent 7b9fd63649
commit 243204bc3e
@@ -34,7 +34,8 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, out_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, hidden_dim):
+        super(LocationSensitiveAttention, self).__init__()
         loc_kernel_size = 31
         loc_dim = 32
         padding = int((loc_kernel_size -1) / 2)
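For orientation, a minimal sketch of the constructor this hunk implies. loc_conv, loc_linear, annot_layer, and v are taken from the context lines of the hunks below; query_layer and the Conv1d configuration (2 input channels matching the documented loc shape, loc_dim filters) are assumptions not shown in this diff.

import torch
import torch.nn as nn

class LocationSensitiveAttention(nn.Module):
    """Location sensitive attention following
    https://arxiv.org/pdf/1506.07503.pdf"""
    def __init__(self, annot_dim, query_dim, hidden_dim):
        super(LocationSensitiveAttention, self).__init__()
        loc_kernel_size = 31
        loc_dim = 32
        padding = int((loc_kernel_size - 1) / 2)
        # convolution over location features; 2 input channels assumed
        # from the documented loc shape (batch, 2, max_time)
        self.loc_conv = nn.Conv1d(2, loc_dim, kernel_size=loc_kernel_size,
                                  padding=padding, bias=False)
        self.loc_linear = nn.Linear(loc_dim, hidden_dim)                # used in forward()
        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)  # assumed
        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)  # context line below
        self.v = nn.Linear(hidden_dim, 1, bias=False)                   # context line below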
@@ -46,13 +47,16 @@ class LocationSensitiveAttention(nn.Module):
         self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
         self.v = nn.Linear(hidden_dim, 1, bias=False)

-    def forward(self, query, annot, loc):
+    def forward(self, annot, query, loc):
         """
         Shapes:
-            - query: (batch, 1, dim) or (batch, dim)
             - annots: (batch, max_time, dim)
+            - query: (batch, 1, dim) or (batch, dim)
             - loc: (batch, 2, max_time)
         """
+        if query.dim() == 2:
+            # insert time-axis for broadcasting
+            query = query.unsqueeze(1)
         loc_conv = self.loc_conv(loc)
         loc_conv = loc_conv.transpose(1, 2)
         processed_loc = self.loc_linear(loc_conv)
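The hunk shows only the head of forward(). A sketch of how the scoring step could continue, following the additive energy of the referenced paper; everything after processed_loc (the query_layer projection, the tanh energy, the softmax) is an assumption, not part of this diff.

    # (continuing the class sketched above)
    def forward(self, annot, query, loc):
        if query.dim() == 2:
            # insert time-axis for broadcasting
            query = query.unsqueeze(1)                     # (batch, 1, query_dim)
        loc_conv = self.loc_conv(loc)                      # (batch, loc_dim, max_time)
        loc_conv = loc_conv.transpose(1, 2)                # (batch, max_time, loc_dim)
        processed_loc = self.loc_linear(loc_conv)          # (batch, max_time, hidden_dim)
        # --- assumed continuation below this line ---
        processed_query = self.query_layer(query)          # broadcasts over max_time
        processed_annots = self.annot_layer(annot)         # (batch, max_time, hidden_dim)
        energies = self.v(torch.tanh(processed_query + processed_annots + processed_loc))
        return torch.softmax(energies.squeeze(-1), dim=1)  # (batch, max_time) alignment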
@@ -68,8 +72,7 @@ class AttentionRNN(nn.Module):
     def __init__(self, out_dim, annot_dim, memory_dim):
         super(AttentionRNN, self).__init__()
         self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
-        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim, 3, out_dim)
-        self.score_mask_value = score_mask_value
+        self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)

     def forward(self, memory, context, rnn_state, annotations,
                 attention_vec, mask=None, annotations_lengths=None):
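The new constructor takes exactly three arguments (annot_dim, query_dim, hidden_dim), so the old five-argument call would raise a TypeError; this hunk brings the call site in line. A hypothetical instantiation, with dimensions chosen purely for illustration:

attn = LocationSensitiveAttention(annot_dim=256, query_dim=256, hidden_dim=256)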
@@ -85,7 +88,7 @@ class AttentionRNN(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
-        alignment = self.alignment_model(annotations, rnn_output)
+        alignment = self.alignment_model(annotations, rnn_output, attention_vec)

         # TODO: needs recheck.
         if mask is not None:
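The alignment model is now also fed attention_vec as its location input. The diff does not show how attention_vec is assembled; given the documented loc shape (batch, 2, max_time), one plausible construction, common for location-sensitive attention, stacks the previous alignment with its running sum (variable names here are hypothetical):

batch, max_time = annotations.size(0), annotations.size(1)
prev_alignment = torch.zeros(batch, max_time)  # alignment from the previous decoder step
cum_alignment = torch.zeros(batch, max_time)   # running sum of all past alignments
attention_vec = torch.stack([prev_alignment, cum_alignment], dim=1)  # (batch, 2, max_time)
alignment = self.alignment_model(annotations, rnn_output, attention_vec)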