mirror of https://github.com/coqui-ai/TTS.git
Commenting the attention code
This commit is contained in:
parent
7acf4eab94
commit
d8c460442a
|
@ -70,6 +70,15 @@ class LocationSensitiveAttention(nn.Module):
|
|||
|
||||
class AttentionRNN(nn.Module):
|
||||
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
|
||||
r"""
|
||||
General Attention RNN wrapper
|
||||
|
||||
Args:
|
||||
out_dim (int): context vector feature dimension.
|
||||
annot_dim (int): annotation vector feature dimension.
|
||||
memory_dim (int): memory vector (decoder autogression) feature dimension.
|
||||
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
||||
"""
|
||||
super(AttentionRNN, self).__init__()
|
||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
||||
# pick bahdanau or location sensitive attention
|
||||
|
|
Loading…
Reference in New Issue