mirror of https://github.com/coqui-ai/TTS.git
remove 'Variable' from train.py
This commit is contained in:
parent
b087c0b5ec
commit
a2eeec31be
53
train.py
53
train.py
|
@ -13,7 +13,6 @@ import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch import onnx
|
from torch import onnx
|
||||||
from torch.autograd import Variable
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
|
@ -93,29 +92,23 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# convert inputs to variables
|
|
||||||
text_input_var = Variable(text_input)
|
|
||||||
mel_spec_var = Variable(mel_input)
|
|
||||||
mel_lengths_var = Variable(mel_lengths)
|
|
||||||
linear_spec_var = Variable(linear_input, volatile=True)
|
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input_var = text_input_var.cuda()
|
text_input = text_input.cuda()
|
||||||
mel_spec_var = mel_spec_var.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths_var = mel_lengths_var.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_spec_var = linear_spec_var.cuda()
|
linear_input = linear_input.cuda()
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments =\
|
mel_output, linear_output, alignments =\
|
||||||
model.forward(text_input_var, mel_spec_var)
|
model.forward(text_input, mel_input)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_spec_var[:, :, :n_priority_freq],
|
linear_input[:, :, :n_priority_freq],
|
||||||
mel_lengths_var)
|
mel_lengths)
|
||||||
loss = mel_loss + linear_loss
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
# backpass and check the grad norm
|
# backpass and check the grad norm
|
||||||
|
@ -157,7 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
const_spec = linear_output[0].data.cpu().numpy()
|
const_spec = linear_output[0].data.cpu().numpy()
|
||||||
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
gt_spec = linear_input[0].data.cpu().numpy()
|
||||||
|
|
||||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||||
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
||||||
|
@ -215,29 +208,23 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
mel_input = data[3]
|
mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
mel_lengths = data[4]
|
||||||
|
|
||||||
# convert inputs to variables
|
|
||||||
text_input_var = Variable(text_input)
|
|
||||||
mel_spec_var = Variable(mel_input)
|
|
||||||
mel_lengths_var = Variable(mel_lengths)
|
|
||||||
linear_spec_var = Variable(linear_input, volatile=True)
|
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input_var = text_input_var.cuda()
|
text_input = text_input.cuda()
|
||||||
mel_spec_var = mel_spec_var.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths_var = mel_lengths_var.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_spec_var = linear_spec_var.cuda()
|
linear_input = linear_input.cuda()
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments =\
|
mel_output, linear_output, alignments =\
|
||||||
model.forward(text_input_var, mel_spec_var)
|
model.forward(text_input, mel_input)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_spec_var[:, :, :n_priority_freq],
|
linear_input[:, :, :n_priority_freq],
|
||||||
mel_lengths_var)
|
mel_lengths)
|
||||||
loss = mel_loss + linear_loss
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
|
@ -255,7 +242,7 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
idx = np.random.randint(mel_input.shape[0])
|
idx = np.random.randint(mel_input.shape[0])
|
||||||
const_spec = linear_output[idx].data.cpu().numpy()
|
const_spec = linear_output[idx].data.cpu().numpy()
|
||||||
gt_spec = linear_spec_var[idx].data.cpu().numpy()
|
gt_spec = linear_input[idx].data.cpu().numpy()
|
||||||
align_img = alignments[idx].data.cpu().numpy()
|
align_img = alignments[idx].data.cpu().numpy()
|
||||||
|
|
||||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||||
|
|
Loading…
Reference in New Issue