diff --git a/layers/tacotron.py b/layers/tacotron.py index 0305bdec..0f2d7497 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -315,7 +315,7 @@ class Decoder(nn.Module): if t >= T_decoder: break else: - if t > 1 and is_end_of_frames(output.view(self.r, -1), alignment, self.eps): + if t > 1 and stop_token.sum().item() > 0.7: break elif t > self.max_decoder_steps: print(" !! Decoder stopped with 'max_decoder_steps'. \ @@ -326,7 +326,7 @@ class Decoder(nn.Module): alignments = torch.stack(alignments).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() stop_tokens = torch.stack(stop_tokens).transpose(0, 1) - return outputs, alignments, stop_tokens + return outputs, alignments, stop_tokens.squeeze() def is_end_of_frames(output, alignment, eps=0.01): # 0.2