From 289a9221d725b0470e500ef68f49537056f3ef8f Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 8 Jul 2020 10:23:53 +0200 Subject: [PATCH] update save_checkpoint for TF tacotron2 --- tf/utils/generic_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tf/utils/generic_utils.py b/tf/utils/generic_utils.py index 6368658d..3d385b10 100644 --- a/tf/utils/generic_utils.py +++ b/tf/utils/generic_utils.py @@ -6,9 +6,7 @@ import numpy as np import tensorflow as tf -def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): - checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step) - checkpoint_path = os.path.join(output_folder, checkpoint_path) +def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): state = { 'model': model.weights, 'optimizer': optimizer, @@ -18,7 +16,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k 'r': r } state.update(kwargs) - pickle.dump(state, open(checkpoint_path, 'wb')) + pickle.dump(state, open(output_path, 'wb')) def load_checkpoint(model, checkpoint_path): @@ -27,7 +25,13 @@ def load_checkpoint(model, checkpoint_path): tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name - chkp_var_value = chkp_var_dict[layer_name] + try: + chkp_var_value = chkp_var_dict[layer_name] + except KeyError: + class_name = list(chkp_var_dict.keys())[0].split("/")[0] + layer_name = f"{class_name}/{layer_name}" + chkp_var_value = chkp_var_dict[layer_name] + tf.keras.backend.set_value(tf_var, chkp_var_value) if 'r' in checkpoint.keys(): model.decoder.set_r(checkpoint['r'])