mirror of https://github.com/coqui-ai/TTS.git
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
import datetime
|
|
import pickle
|
|
|
|
import fsspec
|
|
import tensorflow as tf
|
|
|
|
|
|
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    """Pickle a training checkpoint dict to ``output_path``.

    The stored dict contains ``model`` (the ``model.weights`` list), ``optimizer``
    (the object itself), ``step``, ``epoch``, ``date`` (today's date formatted
    like "January 05, 2021") and ``r``. Entries from ``kwargs`` are merged in
    last, so they override any of those keys.

    Args:
        model: object whose ``weights`` attribute is pickled.
        optimizer: optimizer object, pickled as-is.
        current_step: training step stored under ``"step"``.
        epoch: epoch stored under ``"epoch"``.
        r: value stored under ``"r"``; ``load_checkpoint`` passes it to
            ``model.decoder.set_r``.
        output_path: destination path/URL accepted by ``fsspec.open``.
        **kwargs: extra entries to store in the checkpoint.
    """
    state = dict(
        model=model.weights,
        optimizer=optimizer,
        step=current_step,
        epoch=epoch,
        date=datetime.date.today().strftime("%B %d, %Y"),
        r=r,
    )
    # Caller-supplied entries win over the fixed ones above.
    state.update(kwargs)
    with fsspec.open(output_path, "wb") as handle:
        pickle.dump(state, handle)
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_path):
    """Restore ``model`` weights (and ``r``, if stored) from a pickled checkpoint.

    Each model variable is matched by ``name`` against the variables stored
    under ``checkpoint["model"]``. If a name is not found, it is retried with
    the checkpoint's top-level scope prepended. That scope is the part of the
    first stored variable name before the first "/".

    Args:
        model: model whose ``weights`` are overwritten in place. If the
            checkpoint contains ``"r"``, ``model.decoder.set_r`` is called
            with that value.
        checkpoint_path: path/URL accepted by ``fsspec.open``.

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: if a model variable is found neither under its own name nor
            under the scope-prefixed name.
    """
    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # load checkpoints from trusted sources.
    with fsspec.open(checkpoint_path, "rb") as f:
        checkpoint = pickle.load(f)
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}

    # Scope prefix for the fallback lookup, computed once. Rebuilding the full
    # key list on every miss made the fallback path quadratic.
    first_key = next(iter(chkp_var_dict), None)
    scope = None if first_key is None else first_key.split("/")[0]

    for tf_var in model.weights:
        layer_name = tf_var.name
        if layer_name not in chkp_var_dict and scope is not None:
            # Model variable names may lack the checkpoint's top-level scope.
            layer_name = f"{scope}/{layer_name}"
        if layer_name not in chkp_var_dict:
            raise KeyError(
                f"checkpoint {checkpoint_path!r} has no variable for "
                f"{tf_var.name!r} (also tried {layer_name!r})"
            )
        tf.keras.backend.set_value(tf_var, chkp_var_dict[layer_name])

    if "r" in checkpoint:
        model.decoder.set_r(checkpoint["r"])
    return model
|
|
|
|
|
|
def load_tflite_model(tflite_path):
    """Build a TFLite interpreter for ``tflite_path`` and allocate its tensors.

    Args:
        tflite_path: filesystem path to a TFLite model file.

    Returns:
        The ``tf.lite.Interpreter`` instance, with ``allocate_tensors()``
        already called.
    """
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    # Tensor buffers must be allocated before inputs are set or invoke() runs.
    interpreter.allocate_tensors()
    return interpreter
|