remove 'Variable' from train.py

This commit is contained in:
Eren Golge 2018-05-10 16:05:03 -07:00
parent b087c0b5ec
commit a2eeec31be
1 changed files with 20 additions and 33 deletions

View File

@ -13,7 +13,6 @@ import numpy as np
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch import onnx from torch import onnx
from torch.autograd import Variable
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
@ -93,29 +92,23 @@ def train(model, criterion, data_loader, optimizer, epoch):
optimizer.zero_grad() optimizer.zero_grad()
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
mel_lengths_var = Variable(mel_lengths)
linear_spec_var = Variable(linear_input, volatile=True)
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
text_input_var = text_input_var.cuda() text_input = text_input.cuda()
mel_spec_var = mel_spec_var.cuda() mel_input = mel_input.cuda()
mel_lengths_var = mel_lengths_var.cuda() mel_lengths = mel_lengths.cuda()
linear_spec_var = linear_spec_var.cuda() linear_input = linear_input.cuda()
# forward pass # forward pass
mel_output, linear_output, alignments =\ mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var) model.forward(text_input, mel_input)
# loss computation # loss computation
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var) mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \ linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq], + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[:, :, :n_priority_freq], linear_input[:, :, :n_priority_freq],
mel_lengths_var) mel_lengths)
loss = mel_loss + linear_loss loss = mel_loss + linear_loss
# backpass and check the grad norm # backpass and check the grad norm
@ -157,7 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
# Diagnostic visualizations # Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy() const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_spec_var[0].data.cpu().numpy() gt_spec = linear_input[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap) gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
@ -215,29 +208,23 @@ def evaluate(model, criterion, data_loader, current_step):
mel_input = data[3] mel_input = data[3]
mel_lengths = data[4] mel_lengths = data[4]
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
mel_lengths_var = Variable(mel_lengths)
linear_spec_var = Variable(linear_input, volatile=True)
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
text_input_var = text_input_var.cuda() text_input = text_input.cuda()
mel_spec_var = mel_spec_var.cuda() mel_input = mel_input.cuda()
mel_lengths_var = mel_lengths_var.cuda() mel_lengths = mel_lengths.cuda()
linear_spec_var = linear_spec_var.cuda() linear_input = linear_input.cuda()
# forward pass # forward pass
mel_output, linear_output, alignments =\ mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var) model.forward(text_input, mel_input)
# loss computation # loss computation
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var) mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \ linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq], + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[:, :, :n_priority_freq], linear_input[:, :, :n_priority_freq],
mel_lengths_var) mel_lengths)
loss = mel_loss + linear_loss loss = mel_loss + linear_loss
step_time = time.time() - start_time step_time = time.time() - start_time
@ -255,7 +242,7 @@ def evaluate(model, criterion, data_loader, current_step):
# Diagnostic visualizations # Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0]) idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy() const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_spec_var[idx].data.cpu().numpy() gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap) const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)