From 2e1390dfb1beed54b54d7c85976a6ab30ef2890e Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 3 Aug 2020 13:04:07 +0200 Subject: [PATCH] loss scaling for O1 optimization --- TTS/bin/train_tts.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 3b6bfe35..7925c54e 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -170,7 +170,12 @@ def train(model, criterion, optimizer, optimizer_st, scheduler, text_lengths) # backward pass - loss_dict['loss'].backward() + if amp is not None: + with amp.scale_loss( loss_dict['loss'], optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss_dict['loss'].backward() + optimizer, current_lr = adam_weight_decay(optimizer) if amp: amp_opt_params = amp.master_params(optimizer)