loss scaling for O1 optimization

erogol 2020-08-03 13:04:07 +02:00
parent 0bed77944c
commit 2e1390dfb1
1 changed files with 6 additions and 1 deletions


@@ -170,7 +170,12 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
                              text_lengths)
         # backward pass
-        loss_dict['loss'].backward()
+        if amp is not None:
+            with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
         if amp:
             amp_opt_params = amp.master_params(optimizer)
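
For context: the `amp` object referenced in the diff is NVIDIA apex's mixed-precision handle, and `amp.scale_loss` rescales the loss before backward so that fp16 gradients do not underflow in O1 mode. Below is a minimal, self-contained sketch of how such a handle is typically created and used; the model, optimizer, and dummy loss are placeholders for illustration only, not the project's actual training setup.

    # Minimal sketch of apex O1 loss scaling; only the apex calls
    # (amp.initialize, amp.scale_loss, amp.master_params) mirror the diff,
    # everything else is a stand-in for illustration.
    import torch
    from apex import amp  # assumes apex is installed and a CUDA device is available

    model = torch.nn.Linear(80, 80).cuda()      # stand-in for the TTS model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # O1 casts selected ops to fp16 and keeps fp32 master weights internally.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    x = torch.randn(8, 80).cuda()
    loss = model(x).pow(2).mean()               # dummy loss for illustration

    # Scaled backward pass, mirroring the change in this commit.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Gradient clipping should operate on the fp32 master params, which is
    # what amp.master_params(optimizer) yields (hence the trailing context
    # lines in the hunk above).
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
    optimizer.step()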