mirror of https://github.com/coqui-ai/TTS.git
using amp in training
This commit is contained in:
parent
e2151e77a1
commit
10146357a5
|
@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
else:
|
||||
optimizer_st = None
|
||||
|
||||
if c.apex_amp_level:
|
||||
if c.apex_amp_level == "O1":
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from apex import amp
|
||||
model.cuda()
|
||||
|
@ -578,7 +578,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print("\n > Number of output frames:", model.decoder.r)
|
||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
||||
optimizer_st, scheduler, ap,
|
||||
global_step, epoch)
|
||||
global_step, epoch, amp)
|
||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = train_avg_loss_dict['avg_postnet_loss']
|
||||
|
@ -637,7 +637,7 @@ if __name__ == '__main__':
|
|||
check_config(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.apex_amp_level is 'O1':
|
||||
if c.apex_amp_level == 'O1':
|
||||
print(" > apex AMP level: ", c.apex_amp_level)
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
|
|
|
@ -178,6 +178,7 @@ def check_config(c):
|
|||
check_argument('r', c, restricted=True, val_type=int, min_val=1)
|
||||
check_argument('gradual_training', c, restricted=False, val_type=list)
|
||||
check_argument('loss_masking', c, restricted=True, val_type=bool)
|
||||
check_argument('apex_amp_level', c, restricted=False, val_type=str)
|
||||
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
|
||||
|
||||
# validation parameters
|
||||
|
|
Loading…
Reference in New Issue