More comments for new layers

This commit is contained in:
Eren Golge 2018-05-25 03:25:26 -07:00
parent 4127b66359
commit 256ed6307c
1 changed files with 13 additions and 1 deletions

View File

@ -231,7 +231,7 @@ class Decoder(nn.Module):
Shapes:
- inputs: batch x time x encoder_out_dim
- memory: batch x #mels_pecs x mel_spec_dim
- memory: batch x #mel_specs x mel_spec_dim
"""
B = inputs.size(0)
# Run greedy decoding if memory is None
@ -308,6 +308,13 @@ class Decoder(nn.Module):
class StopNet(nn.Module):
r"""
Predicting stop-token in decoder.
Args:
r (int): number of output frames of the network.
memory_dim (int): feature dimension for each output frame.
"""
def __init__(self, r, memory_dim):
super(StopNet, self).__init__()
@ -317,6 +324,11 @@ class StopNet(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, inputs, rnn_hidden):
"""
Args:
inputs: network output tensor with r x memory_dim feature dimension.
rnn_hidden: hidden state of the RNN cell.
"""
rnn_hidden = self.rnn(inputs, rnn_hidden)
outputs = self.relu(rnn_hidden)
outputs = self.linear(outputs)