using amp in training

This commit is contained in:
erogol 2020-08-03 12:52:51 +02:00
parent e2151e77a1
commit 10146357a5
3 changed files with 6 additions and 5 deletions

View File

@@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         optimizer_st = None
-    if c.apex_amp_level:
+    if c.apex_amp_level == "O1":
         # pylint: disable=import-outside-toplevel
         from apex import amp
         model.cuda()
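Context for this hunk: apex AMP at opt_level "O1" patches eligible ops to run in float16. The diff only shows the import and the guard, so the following is a hedged sketch of the initialization that typically follows `from apex import amp`; the names `model`, `optimizer`, and `optimizer_st` come from the surrounding diff, and the exact call site in the repo may differ.

# Sketch only: typical apex O1 initialization, assuming apex is installed.
# amp.initialize returns patched model/optimizer objects that cast
# eligible ops to float16 under opt_level "O1".
from apex import amp

model.cuda()
if optimizer_st is not None:
    model, [optimizer, optimizer_st] = amp.initialize(
        model, [optimizer, optimizer_st], opt_level="O1")
else:
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")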
@@ -578,7 +578,7 @@ def main(args): # pylint: disable=redefined-outer-name
         print("\n > Number of output frames:", model.decoder.r)
         train_avg_loss_dict, global_step = train(model, criterion, optimizer,
                                                   optimizer_st, scheduler, ap,
-                                                  global_step, epoch)
+                                                  global_step, epoch, amp)
         eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
         target_loss = train_avg_loss_dict['avg_postnet_loss']
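Passing `amp` into `train` implies the backward pass inside the training loop is wrapped with apex's loss scaling. A minimal sketch of the standard apex pattern (not the repo's exact train step):

# Sketch only: standard apex loss-scaling pattern inside a train step.
# Scaling the loss keeps float16 gradients from underflowing; apex
# unscales them again before optimizer.step().
if amp is not None:
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()
optimizer.step()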
@@ -637,7 +637,7 @@ if __name__ == '__main__':
     check_config(c)
     _ = os.path.dirname(os.path.realpath(__file__))
-    if c.apex_amp_level is 'O1':
+    if c.apex_amp_level == 'O1':
         print(" > apex AMP level: ", c.apex_amp_level)
     OUT_PATH = args.continue_path
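The change from `is 'O1'` to `== 'O1'` above is a real bug fix: `is` compares object identity, not string content, so it can be False even for equal strings. A quick illustration:

# 'is' checks identity; equal strings need not be the same object.
level = "".join(["O", "1"])  # builds a new, non-interned string "O1"
print(level == "O1")         # True
print(level is "O1")         # False; CPython 3.8+ also emits a SyntaxWarning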

View File

@@ -85,8 +85,8 @@
     // TACOTRON PRENET
     "memory_size": -1, // ONLY TACOTRON - size of the memory queue used for storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
     "prenet_type": "bn", // "original" or "bn".
     "prenet_dropout": false, // enable/disable dropout at prenet.
     // TACOTRON ATTENTION
     "attention_type": "original", // 'original' or 'graves'

View File

@@ -178,6 +178,7 @@ def check_config(c):
     check_argument('r', c, restricted=True, val_type=int, min_val=1)
     check_argument('gradual_training', c, restricted=False, val_type=list)
     check_argument('loss_masking', c, restricted=True, val_type=bool)
+    check_argument('apex_amp_level', c, restricted=False, val_type=str)
     # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
     # validation parameters
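For readers unfamiliar with this helper: judging by the calls above, `check_argument` validates the presence, type, and range of a config key. A hypothetical minimal implementation matching those call signatures (the repo's actual version may differ):

# Sketch only: semantics implied by the check_argument calls above.
def check_argument(name, c, restricted=False, val_type=None,
                   min_val=None, max_val=None):
    # restricted=True means the key must exist in the config dict.
    if restricted:
        assert name in c, f" [!] '{name}' is required in the config."
    if name in c and c[name] is not None:
        if val_type is not None:
            assert isinstance(c[name], val_type), \
                f" [!] '{name}' must be of type {val_type.__name__}."
        if min_val is not None:
            assert c[name] >= min_val, f" [!] '{name}' must be >= {min_val}."
        if max_val is not None:
            assert c[name] <= max_val, f" [!] '{name}' must be <= {max_val}."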