mirror of https://github.com/coqui-ai/TTS.git
train.py - add with torch.no_grad():
This commit is contained in:
parent
7c40455edd
commit
b4bd713581
65
train.py
65
train.py
|
@ -191,45 +191,46 @@ def evaluate(model, criterion, data_loader, current_step):
|
||||||
print(" | > Validation")
|
print(" | > Validation")
|
||||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
for num_iter, data in enumerate(data_loader):
|
with torch.no_grad():
|
||||||
start_time = time.time()
|
for num_iter, data in enumerate(data_loader):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
# setup input data
|
# setup input data
|
||||||
text_input = data[0]
|
text_input = data[0]
|
||||||
text_lengths = data[1]
|
text_lengths = data[1]
|
||||||
linear_input = data[2]
|
linear_input = data[2]
|
||||||
mel_input = data[3]
|
mel_input = data[3]
|
||||||
mel_lengths = data[4]
|
mel_lengths = data[4]
|
||||||
|
|
||||||
# dispatch data to GPU
|
# dispatch data to GPU
|
||||||
if use_cuda:
|
if use_cuda:
|
||||||
text_input = text_input.cuda()
|
text_input = text_input.cuda()
|
||||||
mel_input = mel_input.cuda()
|
mel_input = mel_input.cuda()
|
||||||
mel_lengths = mel_lengths.cuda()
|
mel_lengths = mel_lengths.cuda()
|
||||||
linear_input = linear_input.cuda()
|
linear_input = linear_input.cuda()
|
||||||
|
|
||||||
# forward pass
|
# forward pass
|
||||||
mel_output, linear_output, alignments =\
|
mel_output, linear_output, alignments =\
|
||||||
model.forward(text_input, mel_input)
|
model.forward(text_input, mel_input)
|
||||||
|
|
||||||
# loss computation
|
# loss computation
|
||||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||||
linear_input[:, :, :n_priority_freq],
|
linear_input[:, :, :n_priority_freq],
|
||||||
mel_lengths)
|
mel_lengths)
|
||||||
loss = mel_loss + linear_loss
|
loss = mel_loss + linear_loss
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
# update
|
# update
|
||||||
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
progbar.update(num_iter+1, values=[('total_loss', loss.item()),
|
||||||
('linear_loss', linear_loss.item()),
|
('linear_loss', linear_loss.item()),
|
||||||
('mel_loss', mel_loss.item())])
|
('mel_loss', mel_loss.item())])
|
||||||
|
|
||||||
avg_linear_loss += linear_loss.item()
|
avg_linear_loss += linear_loss.item()
|
||||||
avg_mel_loss += mel_loss.item()
|
avg_mel_loss += mel_loss.item()
|
||||||
|
|
||||||
# Diagnostic visualizations
|
# Diagnostic visualizations
|
||||||
idx = np.random.randint(mel_input.shape[0])
|
idx = np.random.randint(mel_input.shape[0])
|
||||||
|
|
Loading…
Reference in New Issue