mirror of https://github.com/coqui-ai/TTS.git
Commenting the attention code
This commit is contained in:
parent
7acf4eab94
commit
d8c460442a
|
@ -70,6 +70,15 @@ class LocationSensitiveAttention(nn.Module):
|
||||||
|
|
||||||
class AttentionRNN(nn.Module):
|
class AttentionRNN(nn.Module):
|
||||||
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
|
def __init__(self, out_dim, annot_dim, memory_dim, align_model):
|
||||||
|
r"""
|
||||||
|
General Attention RNN wrapper
|
||||||
|
|
||||||
|
Args:
|
||||||
|
out_dim (int): context vector feature dimension.
|
||||||
|
annot_dim (int): annotation vector feature dimension.
|
||||||
|
memory_dim (int): memory vector (decoder autogression) feature dimension.
|
||||||
|
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
||||||
|
"""
|
||||||
super(AttentionRNN, self).__init__()
|
super(AttentionRNN, self).__init__()
|
||||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
|
||||||
# pick bahdanau or location sensitive attention
|
# pick bahdanau or location sensitive attention
|
||||||
|
|
Loading…
Reference in New Issue