mirror of https://github.com/coqui-ai/TTS.git
make attn guiding optional #2
This commit is contained in:
parent
f197ab1e28
commit
3c8ded5a18
14
train.py
14
train.py
|
@ -72,6 +72,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
print(" | > Epoch {}/{}".format(epoch, c.epochs))
|
||||||
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
progbar = Progbar(len(data_loader.dataset) / c.batch_size)
|
||||||
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
n_priority_freq = int(3000 / (c.sample_rate * 0.5) * c.num_freq)
|
||||||
|
progbar_display = {}
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -127,6 +128,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
attention_loss = criterion(alignments, M, mel_lengths_var)
|
attention_loss = criterion(alignments, M, mel_lengths_var)
|
||||||
loss += mk * attention_loss
|
loss += mk * attention_loss
|
||||||
avg_attn_loss += attention_loss.data[0]
|
avg_attn_loss += attention_loss.data[0]
|
||||||
|
progbar_display['attn_loss'] = attention_loss.data[0]
|
||||||
|
|
||||||
# backpass and check the grad norm
|
# backpass and check the grad norm
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -139,14 +141,14 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
||||||
|
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
|
progbar_display['total_loss'] = loss.data[0]
|
||||||
|
progbar_display['linear_loss'] = linear_loss.data[0]
|
||||||
|
progbar_display['mel_loss'] = mel_loss.data[0]
|
||||||
|
progbar_display['grad_norm'] = grad_norm
|
||||||
|
|
||||||
# update
|
# update
|
||||||
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
progbar.update(num_iter+1, values=[tuple(progbar_display.iteritems())])
|
||||||
('linear_loss',
|
|
||||||
linear_loss.data[0]),
|
|
||||||
('mel_loss', mel_loss.data[0]),
|
|
||||||
('attn_loss', attention_loss.data[0]),
|
|
||||||
('grad_norm', grad_norm)])
|
|
||||||
avg_linear_loss += linear_loss.data[0]
|
avg_linear_loss += linear_loss.data[0]
|
||||||
avg_mel_loss += mel_loss.data[0]
|
avg_mel_loss += mel_loss.data[0]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue