mirror of https://github.com/coqui-ai/TTS.git
More comments for new layers
This commit is contained in:
parent 4127b66359
commit 256ed6307c
@@ -231,7 +231,7 @@ class Decoder(nn.Module):
         Shapes:
             - inputs: batch x time x encoder_out_dim
-            - memory: batch x #mels_pecs x mel_spec_dim
+            - memory: batch x #mel_specs x mel_spec_dim
         """
         B = inputs.size(0)
 
         # Run greedy decoding if memory is None
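As a quick sanity check on the shapes documented in the hunk above, the snippet below builds tensors matching the docstring; the concrete sizes (batch, time, dimensions) are arbitrary illustrative assumptions, not values taken from the repository, and passing memory=None triggers greedy decoding per the comment above.

import torch

batch, time, encoder_out_dim = 2, 50, 256     # illustrative sizes only
n_mel_specs, mel_spec_dim = 120, 80           # illustrative sizes only

inputs = torch.rand(batch, time, encoder_out_dim)      # encoder outputs fed to the decoder
memory = torch.rand(batch, n_mel_specs, mel_spec_dim)  # mel-spectrogram frames (or None for greedy decoding)

assert inputs.shape == (batch, time, encoder_out_dim)
assert memory.shape == (batch, n_mel_specs, mel_spec_dim)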
@@ -308,6 +308,13 @@ class StopNet(nn.Module):
 
 
 class StopNet(nn.Module):
+    r"""
+    Predicting stop-token in decoder.
+
+    Args:
+        r (int): number of output frames of the network.
+        memory_dim (int): feature dimension for each output frame.
+    """
 
     def __init__(self, r, memory_dim):
         super(StopNet, self).__init__()
@@ -317,6 +324,11 @@ class StopNet(nn.Module):
         self.sigmoid = nn.Sigmoid()
 
     def forward(self, inputs, rnn_hidden):
+        """
+        Args:
+            inputs: network output tensor with r x memory_dim feature dimension.
+            rnn_hidden: hidden state of the RNN cell.
+        """
         rnn_hidden = self.rnn(inputs, rnn_hidden)
         outputs = self.relu(rnn_hidden)
         outputs = self.linear(outputs)
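The sketch below is a minimal, self-contained illustration of the interface documented by the new StopNet docstrings. Only the constructor arguments (r, memory_dim), the forward(inputs, rnn_hidden) signature, and the rnn/relu/linear/sigmoid members appear in the hunks above; the GRUCell and Linear sizes, the returned (probability, hidden state) pair, and the StopNetSketch name are assumptions made here for illustration, not the repository's definition.

import torch
import torch.nn as nn


class StopNetSketch(nn.Module):
    """Stop-token predictor following the interface documented in the diff."""

    def __init__(self, r, memory_dim):
        super(StopNetSketch, self).__init__()
        # Assumed sizes: the cell operates on the flattened r x memory_dim frame block.
        self.rnn = nn.GRUCell(r * memory_dim, r * memory_dim)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(r * memory_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, rnn_hidden):
        # inputs: (batch, r * memory_dim), rnn_hidden: (batch, r * memory_dim)
        rnn_hidden = self.rnn(inputs, rnn_hidden)
        outputs = self.relu(rnn_hidden)
        outputs = self.linear(outputs)
        outputs = self.sigmoid(outputs)          # stop probability per decoder step (assumed)
        return outputs, rnn_hidden


if __name__ == "__main__":
    r, memory_dim, batch_size = 5, 80, 2         # illustrative sizes only
    stopnet = StopNetSketch(r, memory_dim)
    frames = torch.rand(batch_size, r * memory_dim)   # r x memory_dim features per step
    hidden = torch.zeros(batch_size, r * memory_dim)  # initial RNN cell state
    stop_prob, hidden = stopnet(frames, hidden)
    print(stop_prob.shape)                       # torch.Size([2, 1])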