diff --git a/layers/tacotron.py b/layers/tacotron.py index f95b1bc1..1dfa5a3c 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -434,7 +434,7 @@ class Decoder(nn.Module): if t >= T_decoder: break else: - if t > inputs.shape[1] / 4 and stop_token > 0.6: + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break elif t > self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps")