mirror of https://github.com/coqui-ai/TTS.git
Loc sens attention
This commit is contained in:
parent
ddaf414434
commit
4e6596a8e1
|
@ -69,24 +69,25 @@ class LocationSensitiveAttention(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class AttentionRNNCell(nn.Module):
|
class AttentionRNNCell(nn.Module):
|
||||||
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
|
def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model):
|
||||||
r"""
|
r"""
|
||||||
General Attention RNN wrapper
|
General Attention RNN wrapper
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
out_dim (int): context vector feature dimension.
|
out_dim (int): context vector feature dimension.
|
||||||
|
rnn_dim (int): rnn hidden state dimension.
|
||||||
annot_dim (int): annotation vector feature dimension.
|
annot_dim (int): annotation vector feature dimension.
|
||||||
memory_dim (int): memory vector (decoder autoregression) feature dimension.
|
memory_dim (int): memory vector (decoder autoregression) feature dimension.
|
||||||
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
||||||
"""
|
"""
|
||||||
super(AttentionRNNCell, self).__init__()
|
super(AttentionRNNCell, self).__init__()
|
||||||
self.align_model = align_model
|
self.align_model = align_model
|
||||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
|
||||||
# pick bahdanau or location sensitive attention
|
# pick bahdanau or location sensitive attention
|
||||||
if align_model == 'b':
|
if align_model == 'b':
|
||||||
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
|
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
|
||||||
if align_model == 'ls':
|
if align_model == 'ls':
|
||||||
self.alignment_model = LocationSensitiveAttention(annot_dim, out_dim, out_dim)
|
self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(" Wrong alignment model name: {}. Use\
|
raise RuntimeError(" Wrong alignment model name: {}. Use\
|
||||||
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(align_model))
|
||||||
|
@ -100,11 +101,10 @@ class AttentionRNNCell(nn.Module):
|
||||||
- context: (batch, dim)
|
- context: (batch, dim)
|
||||||
- rnn_state: (batch, out_dim)
|
- rnn_state: (batch, out_dim)
|
||||||
- annots: (batch, max_time, annot_dim)
|
- annots: (batch, max_time, annot_dim)
|
||||||
- atten: (batch, max_time)
|
- atten: (batch, 2, max_time)
|
||||||
- annot_lens: (batch,)
|
- annot_lens: (batch,)
|
||||||
"""
|
"""
|
||||||
# Concat input query and previous context
|
# Concat input query and previous context
|
||||||
print(context.shape)
|
|
||||||
rnn_input = torch.cat((memory, context), -1)
|
rnn_input = torch.cat((memory, context), -1)
|
||||||
# Feed it to RNN
|
# Feed it to RNN
|
||||||
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
# s_i = f(y_{i-1}, c_{i}, s_{i-1})
|
||||||
|
|
|
@ -203,7 +203,8 @@ class Decoder(nn.Module):
|
||||||
# memory -> |Prenet| -> processed_memory
|
# memory -> |Prenet| -> processed_memory
|
||||||
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
self.prenet = Prenet(memory_dim * r, out_features=[256, 128])
|
||||||
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
|
||||||
self.attention_rnn = AttentionRNNCell(128, in_features, 128, align_model='ls')
|
self.attention_rnn = AttentionRNNCell(out_dim=128, rnn_dim=256, annot_dim=in_features,
|
||||||
|
memory_dim=128, align_model='ls')
|
||||||
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
|
||||||
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
self.project_to_decoder_in = nn.Linear(256+in_features, 256)
|
||||||
# decoder_RNN_input -> |RNN| -> RNN_state
|
# decoder_RNN_input -> |RNN| -> RNN_state
|
||||||
|
|
Loading…
Reference in New Issue