diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index e3f78176..aea6575d 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -128,7 +128,7 @@ class LJSpeechDataset(Dataset): linear = torch.FloatTensor(linear) mel = torch.FloatTensor(mel) mel_lengths = torch.LongTensor(mel_lengths) - stop_targets = torch.FloatTensor(stop_targets).squeeze() + stop_targets = torch.FloatTensor(stop_targets) return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0] diff --git a/models/tacotron.py b/models/tacotron.py index 532db6b8..9495bf7a 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -33,7 +33,7 @@ class Tacotron(nn.Module): # Reshape # batch x time x dim mel_outputs = mel_outputs.view(B, -1, self.mel_dim) - stop_tokens = self.stopnet(mel_outputs) + stop_tokens = self.stopnet(mel_outputs).squeeze() linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens