mirror of https://github.com/coqui-ai/TTS.git
fix for 2 dim memory tensor
This commit is contained in:
parent
e085c4757d
commit
9a2bd7f9af
|
@ -183,6 +183,9 @@ class Decoder(nn.Module):
|
||||||
return outputs, stop_tokens, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def _update_memory(self, memory):
|
def _update_memory(self, memory):
|
||||||
|
if len(memory.shape) == 2:
|
||||||
|
return memory[:, self.mel_channels * (self.r - 1) :]
|
||||||
|
else:
|
||||||
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
||||||
|
|
||||||
def decode(self, memory):
|
def decode(self, memory):
|
||||||
|
|
Loading…
Reference in New Issue