diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py index aea6575d..e3f78176 100644 --- a/datasets/LJSpeech.py +++ b/datasets/LJSpeech.py @@ -128,7 +128,7 @@ class LJSpeechDataset(Dataset): linear = torch.FloatTensor(linear) mel = torch.FloatTensor(mel) mel_lengths = torch.LongTensor(mel_lengths) - stop_targets = torch.FloatTensor(stop_targets) + stop_targets = torch.FloatTensor(stop_targets).squeeze() return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]