mirror of https://github.com/coqui-ai/TTS.git
fix for 2 dim memory tensor
This commit is contained in:
parent
e085c4757d
commit
9a2bd7f9af
|
@ -183,7 +183,10 @@ class Decoder(nn.Module):
|
||||||
return outputs, stop_tokens, alignments
|
return outputs, stop_tokens, alignments
|
||||||
|
|
||||||
def _update_memory(self, memory):
|
def _update_memory(self, memory):
|
||||||
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
if len(memory.shape) == 2:
|
||||||
|
return memory[:, self.mel_channels * (self.r - 1) :]
|
||||||
|
else:
|
||||||
|
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
||||||
|
|
||||||
def decode(self, memory):
|
def decode(self, memory):
|
||||||
query_input = torch.cat((memory, self.context), -1)
|
query_input = torch.cat((memory, self.context), -1)
|
||||||
|
|
Loading…
Reference in New Issue