diff --git a/train.py b/train.py index 9f6acb39..a9945592 100644 --- a/train.py +++ b/train.py @@ -78,8 +78,8 @@ def train(model, criterion, data_loader, optimizer, epoch): # setup input data text_input = data[0] text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] + linear_spec = data[2] + mel_spec = data[3] mel_lengths = data[4] current_step = num_iter + args.restore_step + \ @@ -227,8 +227,8 @@ def evaluate(model, criterion, data_loader, current_step): # setup input data text_input = data[0] text_lengths = data[1] - linear_input = data[2] - mel_input = data[3] + linear_spec = data[2] + mel_spec = data[3] mel_lengths = data[4] # dispatch data to GPU