# File: coqui-tts/TTS/tts/tf/utils/io.py
# (source-viewer metadata: 46 lines, 1.3 KiB, Python)

import datetime
import pickle
import fsspec
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    """Pickle a training checkpoint dict to ``output_path``.

    The stored dict has the keys ``"model"`` (whatever ``model.weights``
    returns), ``"optimizer"`` (the optimizer object itself), ``"step"``,
    ``"epoch"``, ``"date"`` (today's date formatted with ``"%B %d, %Y"``)
    and ``"r"``. Extra keyword arguments are merged in last, so a kwarg with
    the same name as one of those keys overwrites it.

    Args:
        model: Object exposing a ``weights`` attribute (presumably a Keras
            model -- TODO confirm against callers).
        optimizer: Stored as-is, so it must be picklable.
        current_step: Stored under ``"step"``.
        epoch: Stored under ``"epoch"``.
        r: Stored under ``"r"``; ``load_checkpoint`` later hands this value
            to ``model.decoder.set_r``.
        output_path: Destination passed to ``fsspec.open`` in ``"wb"`` mode.
        **kwargs: Additional entries to store in the checkpoint dict.
    """
    payload = {
        "model": model.weights,
        "optimizer": optimizer,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
        # Caller extras come last so they take precedence (same as dict.update).
        **kwargs,
    }
    with fsspec.open(output_path, "wb") as stream:
        pickle.dump(payload, stream)
def load_checkpoint(model, checkpoint_path):
    """Restore ``model`` weights (and decoder ``r``) from a pickled checkpoint.

    Checkpoint variables are matched to ``model.weights`` by ``.name``. If a
    model variable's name is not present verbatim, the lookup is retried with
    the scope prefix of the checkpoint's first variable prepended
    (``"<prefix>/<name>"``). This handles checkpoints whose variable names
    carry an extra leading scope component.

    Args:
        model: Model whose ``weights`` are overwritten in place via
            ``tf.keras.backend.set_value``. If the checkpoint contains an
            ``"r"`` entry, ``model.decoder.set_r`` is called with it.
        checkpoint_path: Source passed to ``fsspec.open`` in ``"rb"`` mode;
            expected to hold a dict written by ``save_checkpoint``.

    Returns:
        The same ``model`` object, for chaining.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry, or if a model
            variable matches no checkpoint variable under either its own
            name or the prefixed name.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load checkpoints from trusted sources.
    with fsspec.open(checkpoint_path, "rb") as f:
        checkpoint = pickle.load(f)
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}

    # Fallback scope prefix, taken from the first saved variable's name.
    # It is computed once here rather than rebuilding the key list on every
    # missed lookup. It is None for an empty checkpoint.
    first_name = next(iter(chkp_var_dict), None)
    prefix = first_name.split("/")[0] if first_name is not None else None

    for tf_var in model.weights:
        layer_name = tf_var.name
        if layer_name not in chkp_var_dict and prefix is not None:
            layer_name = f"{prefix}/{layer_name}"
        if layer_name not in chkp_var_dict:
            raise KeyError(
                f"variable {tf_var.name!r} (also tried {layer_name!r}) "
                f"not found in checkpoint {checkpoint_path!r}"
            )
        tf.keras.backend.set_value(tf_var, chkp_var_dict[layer_name])

    if "r" in checkpoint:
        model.decoder.set_r(checkpoint["r"])
    return model
def load_tflite_model(tflite_path):
    """Build a ``tf.lite.Interpreter`` for the model file at ``tflite_path``.

    ``allocate_tensors()`` is called before returning, so the caller receives
    an interpreter whose tensor buffers are already allocated.

    Args:
        tflite_path: Filesystem path passed as ``model_path`` to the interpreter.

    Returns:
        The ``tf.lite.Interpreter`` instance.
    """
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    return interpreter