diff --git a/layers/tacotron.py b/layers/tacotron.py index 4d23835f..756d9b26 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -232,6 +232,7 @@ class Decoder(nn.Module): [nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, memory_dim * r) + self.stopnet = nn.Sequential(nn.Dropout(0.2), nn.Linear(memory_dim * self.r, 1), nn.Sigmoid()) def forward(self, inputs, memory=None): """ @@ -272,6 +273,7 @@ class Decoder(nn.Module): memory = memory.transpose(0, 1) outputs = [] alignments = [] + stop_tokens = [] t = 0 memory_input = initial_memory while True: @@ -297,8 +299,11 @@ class Decoder(nn.Module): output = decoder_input # predict mel vectors from decoder vectors output = self.proj_to_mel(output) + # predict stop token + stop_token = self.stopnet(output) outputs += [output] alignments += [alignment] + stop_tokens += stop_token t += 1 if (not greedy and self.training) or (greedy and memory is not None): if t >= T_decoder: @@ -314,7 +319,8 @@ class Decoder(nn.Module): # Back to batch first alignments = torch.stack(alignments).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() - return outputs, alignments + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + return outputs, alignments, stop_tokens def is_end_of_frames(output, alignment, eps=0.05): # 0.2