mirror of https://github.com/coqui-ai/TTS.git
More comments for new layers
This commit is contained in:
parent
4127b66359
commit
256ed6307c
|
@ -231,7 +231,7 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- inputs: batch x time x encoder_out_dim
|
- inputs: batch x time x encoder_out_dim
|
||||||
- memory: batch x #mels_pecs x mel_spec_dim
|
- memory: batch x #mel_specs x mel_spec_dim
|
||||||
"""
|
"""
|
||||||
B = inputs.size(0)
|
B = inputs.size(0)
|
||||||
# Run greedy decoding if memory is None
|
# Run greedy decoding if memory is None
|
||||||
|
@ -308,6 +308,13 @@ class Decoder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class StopNet(nn.Module):
|
class StopNet(nn.Module):
|
||||||
|
r"""
|
||||||
|
Predicting stop-token in decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
r (int): number of output frames of the network.
|
||||||
|
memory_dim (int): feature dimension for each output frame.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, r, memory_dim):
|
def __init__(self, r, memory_dim):
|
||||||
super(StopNet, self).__init__()
|
super(StopNet, self).__init__()
|
||||||
|
@ -317,6 +324,11 @@ class StopNet(nn.Module):
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
def forward(self, inputs, rnn_hidden):
|
def forward(self, inputs, rnn_hidden):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inputs: network output tensor with r x memory_dim feature dimension.
|
||||||
|
rnn_hidden: hidden state of the RNN cell.
|
||||||
|
"""
|
||||||
rnn_hidden = self.rnn(inputs, rnn_hidden)
|
rnn_hidden = self.rnn(inputs, rnn_hidden)
|
||||||
outputs = self.relu(rnn_hidden)
|
outputs = self.relu(rnn_hidden)
|
||||||
outputs = self.linear(outputs)
|
outputs = self.linear(outputs)
|
||||||
|
|
Loading…
Reference in New Issue