mirror of https://github.com/coqui-ai/TTS.git
attentio update
This commit is contained in:
parent
d4f1ccd3ed
commit
ddaf414434
|
@ -5,11 +5,11 @@ from utils.generic_utils import sequence_mask
|
|||
|
||||
|
||||
class BahdanauAttention(nn.Module):
|
||||
def __init__(self, annot_dim, query_dim, hidden_dim):
|
||||
def __init__(self, annot_dim, query_dim, attn_dim):
|
||||
super(BahdanauAttention, self).__init__()
|
||||
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
|
||||
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
|
||||
self.v = nn.Linear(hidden_dim, 1, bias=False)
|
||||
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
|
||||
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
|
||||
def forward(self, annots, query):
|
||||
"""
|
||||
|
@ -33,8 +33,8 @@ class BahdanauAttention(nn.Module):
|
|||
class LocationSensitiveAttention(nn.Module):
|
||||
"""Location sensitive attention following
|
||||
https://arxiv.org/pdf/1506.07503.pdf"""
|
||||
def __init__(self, annot_dim, query_dim, hidden_dim,
|
||||
kernel_size=7, filters=20):
|
||||
def __init__(self, annot_dim, query_dim, attn_dim,
|
||||
kernel_size=31, filters=32):
|
||||
super(LocationSensitiveAttention, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.filters = filters
|
||||
|
@ -42,10 +42,10 @@ class LocationSensitiveAttention(nn.Module):
|
|||
self.loc_conv = nn.Conv1d(1, filters,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=padding, bias=False)
|
||||
self.loc_linear = nn.Linear(filters, hidden_dim)
|
||||
self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
|
||||
self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
|
||||
self.v = nn.Linear(hidden_dim, 1, bias=False)
|
||||
self.loc_linear = nn.Linear(filters, attn_dim)
|
||||
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
|
||||
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
|
||||
self.v = nn.Linear(attn_dim, 1, bias=False)
|
||||
|
||||
def forward(self, annot, query, loc):
|
||||
"""
|
||||
|
@ -104,6 +104,7 @@ class AttentionRNNCell(nn.Module):
|
|||
- annot_lens: (batch,)
|
||||
"""
|
||||
# Concat input query and previous context context
|
||||
print(context.shape)
|
||||
rnn_input = torch.cat((memory, context), -1)
|
||||
# Feed it to RNN
|
||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||
|
|
|
@ -203,7 +203,7 @@ class Decoder(nn.Module):
|
|||
# memory -> |Prenet| -> processed_memory
|
||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||
self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
|
||||
self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
|
||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||
|
@ -248,7 +248,7 @@ class Decoder(nn.Module):
|
|||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||
for _ in range(len(self.decoder_rnns))]
|
||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||
current_context_vec = inputs.data.new(B, 128).zero_()
|
||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
# attention states
|
||||
attention = inputs.data.new(B, T).zero_()
|
||||
|
|
Loading…
Reference in New Issue