From a4474abd8365939f64db1799a298fa7b5afe3588 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 6 Mar 2019 13:10:54 +0100 Subject: [PATCH] tacotron parse output bug fix --- layers/tacotron.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index e895968c..6865c902 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -376,12 +376,12 @@ class Decoder(nn.Module): self.attention = inputs.data.new(B, T).zero_() self.attention_cum = inputs.data.new(B, T).zero_() - def _parse_outputs(self, outputs, stop_tokens, attentions): + def _parse_outputs(self, outputs, attentions, stop_tokens): # Back to batch first attentions = torch.stack(attentions).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() - stop_tokens = torch.stack(stop_tokens).transpose(0, 1) - return outputs, stop_tokens, attentions + stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1) + return outputs, attentions, stop_tokens def decode(self, inputs,