diff --git a/layers/tacotron.py b/layers/tacotron.py index 1f0d0be6..8915f385 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -395,7 +395,7 @@ class Decoder(nn.Module): return output, stop_token, self.attention_layer.attention_weights def _update_memory_queue(self, new_memory): - if self.memory_size > 0: + if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size: self.memory_input = torch.cat([ self.memory_input[:, self.r * self.memory_dim:].clone(), new_memory