update save_checkpoint for TF tacotron2

This commit is contained in:
erogol 2020-07-08 10:23:53 +02:00
parent dfd5e3cbfc
commit 289a9221d7
1 changed files with 9 additions and 5 deletions

View File

@ -6,9 +6,7 @@ import numpy as np
import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step)
checkpoint_path = os.path.join(output_folder, checkpoint_path)
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
'model': model.weights,
'optimizer': optimizer,
@ -18,7 +16,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
'r': r
}
state.update(kwargs)
pickle.dump(state, open(checkpoint_path, 'wb'))
pickle.dump(state, open(output_path, 'wb'))
def load_checkpoint(model, checkpoint_path):
@ -27,7 +25,13 @@ def load_checkpoint(model, checkpoint_path):
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
chkp_var_value = chkp_var_dict[layer_name]
try:
chkp_var_value = chkp_var_dict[layer_name]
except KeyError:
class_name = list(chkp_var_dict.keys())[0].split("/")[0]
layer_name = f"{class_name}/{layer_name}"
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if 'r' in checkpoint.keys():
model.decoder.set_r(checkpoint['r'])