diff --git a/layers/tacotron.py b/layers/tacotron.py index 1e673cca..5256d215 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -331,7 +331,13 @@ class Decoder(nn.Module): class StopNet(nn.Module): def __init__(self, r, memory_dim): - """Predicts the stop token to stop the decoder at testing time""" + r""" + Predicts the stop token to stop the decoder at testing time + + Args: + r (int): number of network output frames. + memory_dim (int): single feature dim of a single network output frame. + """ super(StopNet, self).__init__() self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r) self.relu = nn.ReLU()