diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index 26c08150..09d40285 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -539,12 +539,16 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: checkpoint = torch.load(args.restore_path, map_location='cpu') try: - # TODO: fix optimizer init, model.cuda() needs to be called before + print(" > Restoring Model.") + model.load_state_dict(checkpoint['model']) # optimizer restore - # optimizer.load_state_dict(checkpoint['optimizer']) + print(" > Restoring Optimizer.") + optimizer.load_state_dict(checkpoint['optimizer']) + if "scaler" in checkpoint and c.mixed_precision: + print(" > Restoring AMP Scaler...") + scaler.load_state_dict(checkpoint["scaler"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) except KeyError: print(" > Partial model initialization.") model_dict = model.state_dict()