From 528616595c12e62c8f7f949542a2feb14b20bdb8 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 10 May 2018 16:05:03 -0700 Subject: [PATCH] remove 'Variable' from train.py --- train.py | 53 ++++++++++++++++++++--------------------------------- 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/train.py b/train.py index 0d4a2cf2..88886c3f 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,6 @@ import numpy as np import torch.nn as nn from torch import optim from torch import onnx -from torch.autograd import Variable from torch.utils.data import DataLoader from torch.optim.lr_scheduler import ReduceLROnPlateau from tensorboardX import SummaryWriter @@ -93,29 +92,23 @@ def train(model, criterion, data_loader, optimizer, epoch): optimizer.zero_grad() - # convert inputs to variables - text_input_var = Variable(text_input) - mel_spec_var = Variable(mel_input) - mel_lengths_var = Variable(mel_lengths) - linear_spec_var = Variable(linear_input, volatile=True) - # dispatch data to GPU if use_cuda: - text_input_var = text_input_var.cuda() - mel_spec_var = mel_spec_var.cuda() - mel_lengths_var = mel_lengths_var.cuda() - linear_spec_var = linear_spec_var.cuda() + text_input = text_input.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + linear_input = linear_input.cuda() # forward pass mel_output, linear_output, alignments =\ - model.forward(text_input_var, mel_spec_var) + model.forward(text_input, mel_input) # loss computation - mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var) - linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \ + mel_loss = criterion(mel_output, mel_input, mel_lengths) + linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec_var[:, :, :n_priority_freq], - mel_lengths_var) + linear_input[:, :, :n_priority_freq], + mel_lengths) loss = mel_loss + linear_loss # backpass and check the grad norm @@ -157,7 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch): # Diagnostic visualizations const_spec = linear_output[0].data.cpu().numpy() - gt_spec = linear_spec_var[0].data.cpu().numpy() + gt_spec = linear_input[0].data.cpu().numpy() const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) @@ -215,29 +208,23 @@ def evaluate(model, criterion, data_loader, current_step): mel_input = data[3] mel_lengths = data[4] - # convert inputs to variables - text_input_var = Variable(text_input) - mel_spec_var = Variable(mel_input) - mel_lengths_var = Variable(mel_lengths) - linear_spec_var = Variable(linear_input, volatile=True) - # dispatch data to GPU if use_cuda: - text_input_var = text_input_var.cuda() - mel_spec_var = mel_spec_var.cuda() - mel_lengths_var = mel_lengths_var.cuda() - linear_spec_var = linear_spec_var.cuda() + text_input = text_input.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + linear_input = linear_input.cuda() # forward pass mel_output, linear_output, alignments =\ - model.forward(text_input_var, mel_spec_var) + model.forward(text_input, mel_input) # loss computation - mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var) - linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \ + mel_loss = criterion(mel_output, mel_input, mel_lengths) + linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ + 0.5 * criterion(linear_output[:, :, :n_priority_freq], - linear_spec_var[:, :, :n_priority_freq], - mel_lengths_var) + linear_input[:, :, :n_priority_freq], + mel_lengths) loss = mel_loss + linear_loss step_time = time.time() - start_time @@ -255,7 +242,7 @@ def evaluate(model, criterion, data_loader, current_step): # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) const_spec = linear_output[idx].data.cpu().numpy() - gt_spec = linear_spec_var[idx].data.cpu().numpy() + gt_spec = linear_input[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy() const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)