diff --git a/layers/tacotron.py b/layers/tacotron.py index 7df5b1e8..96d187e1 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -232,7 +232,12 @@ class Decoder(nn.Module): [nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) - self.stopnet = nn.Sequential(nn.Dropout(0.2), nn.Linear(memory_dim * self.r, 1), nn.Sigmoid()) + self.stopnet = nn.Sequential(nn.Linear(memory_dim * self.r, memory_dim), + nn.ReLU(), + nn.Linear(memory_dim, memory_dim * self.r), + nn.ReLU(), + nn.Linear(memory_dim * self.r, 1), + nn.Sigmoid()) def forward(self, inputs, memory=None): """