mirror of https://github.com/coqui-ai/TTS.git
attention update
This commit is contained in:
parent 766be91a09
commit d6947c0f13
@@ -5,11 +5,11 @@ from utils.generic_utils import sequence_mask
 
 
 class BahdanauAttention(nn.Module):
-    def __init__(self, annot_dim, query_dim, hidden_dim):
+    def __init__(self, annot_dim, query_dim, attn_dim):
         super(BahdanauAttention, self).__init__()
-        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
-        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
-        self.v = nn.Linear(hidden_dim, 1, bias=False)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
 
     def forward(self, annots, query):
         """
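
For context: attn_dim (renamed from hidden_dim here) is the size of the additive-attention projection. A minimal sketch of the score these three layers feed, assuming the standard Bahdanau formulation; the forward body below is illustrative, not the file's verbatim code:

import torch
import torch.nn as nn

annot_dim, query_dim, attn_dim = 256, 256, 128   # illustrative sizes, not from the diff
query_layer = nn.Linear(query_dim, attn_dim, bias=True)
annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
v = nn.Linear(attn_dim, 1, bias=False)

annots = torch.randn(4, 50, annot_dim)   # encoder outputs: (batch, time, annot_dim)
query = torch.randn(4, query_dim)        # decoder state:   (batch, query_dim)
# e_ij = v^T tanh(W_q s_i + W_a h_j); softmax over time yields the alignment
energies = v(torch.tanh(query_layer(query).unsqueeze(1) + annot_layer(annots)))
alignment = torch.softmax(energies.squeeze(-1), dim=-1)  # (batch, time)
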
@@ -33,8 +33,8 @@ class BahdanauAttention(nn.Module):
 class LocationSensitiveAttention(nn.Module):
     """Location sensitive attention following
     https://arxiv.org/pdf/1506.07503.pdf"""
-    def __init__(self, annot_dim, query_dim, hidden_dim,
-                 kernel_size=7, filters=20):
+    def __init__(self, annot_dim, query_dim, attn_dim,
+                 kernel_size=31, filters=32):
         super(LocationSensitiveAttention, self).__init__()
         self.kernel_size = kernel_size
         self.filters = filters
@@ -42,10 +42,10 @@ class LocationSensitiveAttention(nn.Module):
         self.loc_conv = nn.Conv1d(1, filters,
                                   kernel_size=kernel_size, stride=1,
                                   padding=padding, bias=False)
-        self.loc_linear = nn.Linear(filters, hidden_dim)
-        self.query_layer = nn.Linear(query_dim, hidden_dim, bias=True)
-        self.annot_layer = nn.Linear(annot_dim, hidden_dim, bias=True)
-        self.v = nn.Linear(hidden_dim, 1, bias=False)
+        self.loc_linear = nn.Linear(filters, attn_dim)
+        self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
+        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
+        self.v = nn.Linear(attn_dim, 1, bias=False)
 
     def forward(self, annot, query, loc):
         """
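
The location-sensitive variant extends the Bahdanau score with features convolved from the previous alignment (Chorowski et al., 2015). The padding value is computed outside this hunk; the sketch below assumes the usual "same" padding for an odd kernel, (kernel_size - 1) // 2, which would be 15 for the new kernel_size=31:

import torch
import torch.nn as nn

annot_dim, query_dim, attn_dim = 256, 256, 128   # illustrative sizes
kernel_size, filters = 31, 32                    # new defaults from this commit
padding = (kernel_size - 1) // 2                 # assumption; not shown in the diff
loc_conv = nn.Conv1d(1, filters, kernel_size=kernel_size, stride=1,
                     padding=padding, bias=False)
loc_linear = nn.Linear(filters, attn_dim)
query_layer = nn.Linear(query_dim, attn_dim, bias=True)
annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
v = nn.Linear(attn_dim, 1, bias=False)

annot = torch.randn(4, 50, annot_dim)
query = torch.randn(4, query_dim)
loc = torch.rand(4, 50)                          # previous alignment: (batch, time)
loc_feat = loc_conv(loc.unsqueeze(1))            # (batch, filters, time)
loc_term = loc_linear(loc_feat.transpose(1, 2))  # (batch, time, attn_dim)
# the location term is added into the Bahdanau energy before the tanh
energies = v(torch.tanh(query_layer(query).unsqueeze(1)
                        + annot_layer(annot) + loc_term))
alignment = torch.softmax(energies.squeeze(-1), dim=-1)

A wider kernel (31 vs. 7) lets the location features see a larger window of the previous alignment, which tends to stabilize monotonic attention on long utterances.
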
@@ -104,6 +104,7 @@ class AttentionRNNCell(nn.Module):
         - annot_lens: (batch,)
         """
         # Concat input query and previous context context
+        print(context.shape)
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
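
The comment above states the attention-RNN recurrence s_i = f(y_{i-1}, c_i, s_{i-1}): the previous output and the current context are concatenated and fed to a recurrent cell. A sketch with an assumed nn.GRUCell, since the cell type is not visible in this hunk:

import torch
import torch.nn as nn

memory_dim, context_dim, rnn_dim = 128, 128, 256  # illustrative sizes
rnn_cell = nn.GRUCell(memory_dim + context_dim, rnn_dim)

memory = torch.randn(4, memory_dim)    # y_{i-1}: processed previous output
context = torch.randn(4, context_dim)  # c_i: current attention context
s_prev = torch.zeros(4, rnn_dim)       # s_{i-1}: previous hidden state
rnn_input = torch.cat((memory, context), -1)
s_i = rnn_cell(rnn_input, s_prev)      # (batch, rnn_dim)
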
@@ -203,7 +203,7 @@ class Decoder(nn.Module):
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
-        self.attention_rnn = AttentionRNNCell(256, in_features, 128, align_model='ls')
+        self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256+in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
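
The comments in this hunk trace the decoder dataflow. The projection step into the decoder RNNs, sketched using only the shapes visible here, reading 256 as the attention RNN hidden size and in_features as the context size (names and the concat order are assumptions):

import torch
import torch.nn as nn

B, in_features = 4, 256
project_to_decoder_in = nn.Linear(256 + in_features, 256)
attention_rnn_out = torch.randn(B, 256)   # attention RNN hidden state
context = torch.randn(B, in_features)     # attention context vector
decoder_input = project_to_decoder_in(torch.cat((attention_rnn_out, context), -1))
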
@@ -248,7 +248,7 @@ class Decoder(nn.Module):
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
-        current_context_vec = inputs.data.new(B, 256).zero_()
+        current_context_vec = inputs.data.new(B, 128).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
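
These are the decoder's zero-initialized recurrent states; current_context_vec shrinks from 256 to 128 to match the attention RNN change above. The same initialization written with new_zeros, which is equivalent to data.new(...).zero_() (sizes here are illustrative assumptions):

import torch

B, T = 4, 50                              # batch size and encoder time steps
r, memory_dim, n_decoder_rnns = 5, 80, 2  # assumed values for the sketch
inputs = torch.randn(B, T, 256)
attention_rnn_hidden = inputs.new_zeros(B, 256)
decoder_rnn_hiddens = [inputs.new_zeros(B, 256) for _ in range(n_decoder_rnns)]
current_context_vec = inputs.new_zeros(B, 128)    # was (B, 256) before this commit
stopnet_rnn_hidden = inputs.new_zeros(B, r * memory_dim)
attention = inputs.new_zeros(B, T)
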