diff --git a/train.py b/train.py index 7011b024..7b5a02fe 100644 --- a/train.py +++ b/train.py @@ -128,8 +128,6 @@ def train(model, criterion, data_loader, optimizer, epoch): + 0.5 * criterion(linear_output[:, :, :n_priority_freq], linear_spec_var[:, :, :n_priority_freq], mel_lengths_var) - print(M.shape) - print(alignments.shape) attention_loss = criterion(alignments, M, mel_lengths_var) loss = mel_loss + linear_loss + 0.2 * attention_loss