using amp in training

This commit is contained in:
erogol 2020-08-03 12:52:51 +02:00
parent e2151e77a1
commit 10146357a5
3 changed files with 6 additions and 5 deletions

View File

@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
else:
optimizer_st = None
if c.apex_amp_level:
if c.apex_amp_level == "O1":
# pylint: disable=import-outside-toplevel
from apex import amp
model.cuda()
@ -578,7 +578,7 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Number of output frames:", model.decoder.r)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
optimizer_st, scheduler, ap,
global_step, epoch)
global_step, epoch, amp)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss']
@ -637,7 +637,7 @@ if __name__ == '__main__':
check_config(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level is 'O1':
if c.apex_amp_level == 'O1':
print(" > apex AMP level: ", c.apex_amp_level)
OUT_PATH = args.continue_path

View File

@ -85,8 +85,8 @@
// TACOTRON PRENET
"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
"prenet_type": "bn", // "original" or "bn".
"prenet_dropout": false, // enable/disable dropout at prenet.
"prenet_type": "bn", // "original" or "bn".
"prenet_dropout": false, // enable/disable dropout at prenet.
// TACOTRON ATTENTION
"attention_type": "original", // 'original' or 'graves'

View File

@ -178,6 +178,7 @@ def check_config(c):
check_argument('r', c, restricted=True, val_type=int, min_val=1)
check_argument('gradual_training', c, restricted=False, val_type=list)
check_argument('loss_masking', c, restricted=True, val_type=bool)
check_argument('apex_amp_level', c, restricted=False, val_type=str)
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
# validation parameters