mirror of https://github.com/coqui-ai/TTS.git
loss scaling for O1 optimization
This commit is contained in:
parent
0bed77944c
commit
2e1390dfb1
|
@ -170,7 +170,12 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
text_lengths)
|
||||
|
||||
# backward pass
|
||||
loss_dict['loss'].backward()
|
||||
if amp is not None:
|
||||
with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss_dict['loss'].backward()
|
||||
|
||||
optimizer, current_lr = adam_weight_decay(optimizer)
|
||||
if amp:
|
||||
amp_opt_params = amp.master_params(optimizer)
|
||||
|
|
Loading…
Reference in New Issue