import datetime
import pickle

import fsspec
import tensorflow as tf


def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    """Pickle the training state to ``output_path``.

    The saved dict holds the model weights, the optimizer object, the step
    and epoch counters, today's date as a formatted string, and ``r``.

    Args:
        model: Object exposing a ``weights`` attribute; the weights are
            pickled as-is.
        optimizer: Optimizer object, pickled as-is.
        current_step: Value stored under the ``"step"`` key.
        epoch: Value stored under the ``"epoch"`` key.
        r: Value stored under the ``"r"`` key (``load_checkpoint`` passes it
            to ``model.decoder.set_r``).
        output_path: Destination path or URL, opened through ``fsspec``.
        **kwargs: Extra entries merged into the state; a key that matches a
            built-in one overrides it.
    """
    state = {
        "model": model.weights,
        "optimizer": optimizer,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
    }
    # Applied last so that caller-supplied entries win over the defaults.
    state.update(kwargs)
    with fsspec.open(output_path, "wb") as f:
        pickle.dump(state, f)


def load_checkpoint(model, checkpoint_path):
    """Restore the weights of ``model`` from a pickled checkpoint.

    Each variable in ``model.weights`` is matched to a checkpoint variable by
    name. If the exact name is missing, the lookup is retried with the
    top-level scope of the first checkpoint variable prepended.

    Args:
        model: Model whose ``weights`` are overwritten in place. If the
            checkpoint has an ``"r"`` entry, ``model.decoder.set_r`` is
            called with it.
        checkpoint_path: Path or URL of the checkpoint, opened through
            ``fsspec``.

    Returns:
        The same ``model`` object, with its weights updated.

    Raises:
        KeyError: If a model variable is found in the checkpoint neither
            under its own name nor under the prefixed fallback name.
    """
    with fsspec.open(checkpoint_path, "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code while
        # deserializing. Only load checkpoints from trusted sources.
        checkpoint = pickle.load(f)

    chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}

    # The fallback prefix is loop-invariant, so compute it once. It is the
    # leading scope component of the first checkpoint variable name
    # (presumably the wrapping model class name -- confirm against the code
    # that produced the checkpoint). An empty checkpoint yields "" here and
    # is reported as a KeyError below, on the first model variable.
    scope_prefix = next(iter(chkp_var_dict), "").split("/")[0]

    for tf_var in model.weights:
        layer_name = tf_var.name
        try:
            chkp_var_value = chkp_var_dict[layer_name]
        except KeyError:
            prefixed_name = f"{scope_prefix}/{layer_name}"
            if prefixed_name not in chkp_var_dict:
                raise KeyError(
                    f"variable {layer_name!r} not found in checkpoint "
                    f"(also tried {prefixed_name!r})"
                ) from None
            chkp_var_value = chkp_var_dict[prefixed_name]
        tf.keras.backend.set_value(tf_var, chkp_var_value)

    if "r" in checkpoint:
        model.decoder.set_r(checkpoint["r"])
    return model


def load_tflite_model(tflite_path):
    """Create a TFLite interpreter for ``tflite_path`` with tensors allocated.

    Args:
        tflite_path: Path of the ``.tflite`` model file.

    Returns:
        The ``tf.lite.Interpreter`` after ``allocate_tensors()`` was called.
    """
    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
    tflite_model.allocate_tensors()
    return tflite_model