diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index f5bc364f..3b6bfe35 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -506,7 +506,7 @@ def main(args): # pylint: disable=redefined-outer-name else: optimizer_st = None - if c.apex_amp_level: + if c.apex_amp_level == "O1": # pylint: disable=import-outside-toplevel from apex import amp model.cuda() @@ -578,7 +578,7 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Number of output frames:", model.decoder.r) train_avg_loss_dict, global_step = train(model, criterion, optimizer, optimizer_st, scheduler, ap, - global_step, epoch) + global_step, epoch, amp) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = train_avg_loss_dict['avg_postnet_loss'] @@ -637,7 +637,7 @@ if __name__ == '__main__': check_config(c) _ = os.path.dirname(os.path.realpath(__file__)) - if c.apex_amp_level is 'O1': + if c.apex_amp_level == "O1": print(" > apex AMP level: ", c.apex_amp_level) OUT_PATH = args.continue_path diff --git a/TTS/tts/configs/config.json b/TTS/tts/configs/config.json index 6a60fe81..cd4595b9 100644 --- a/TTS/tts/configs/config.json +++ b/TTS/tts/configs/config.json @@ -85,8 +85,8 @@ // TACOTRON PRENET "memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame. - "prenet_type": "bn", // "original" or "bn". - "prenet_dropout": false, // enable/disable dropout at prenet. + "prenet_type": "bn", // "original" or "bn". + "prenet_dropout": false, // enable/disable dropout at prenet.
// TACOTRON ATTENTION "attention_type": "original", // 'original' or 'graves' diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index a781b5a4..f06b52de 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -178,6 +178,7 @@ def check_config(c): check_argument('r', c, restricted=True, val_type=int, min_val=1) check_argument('gradual_training', c, restricted=False, val_type=list) check_argument('loss_masking', c, restricted=True, val_type=bool) + check_argument('apex_amp_level', c, restricted=False, val_type=str) # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # validation parameters