bug fix: checkpoints

This commit is contained in:
Steve Nyemba 2022-09-16 22:39:25 -05:00
parent 4be340ec08
commit 209a7b8ee5
1 changed file with 5 additions and 3 deletions

View File

@@ -103,7 +103,7 @@ class GNet :
CHECKPOINT_SKIPS = 10 CHECKPOINT_SKIPS = 10
if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
CHECKPOINT_SKIPS = 2 CHECKPOINT_SKIPS = 2
self.CHECKPOINTS = np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
@@ -529,7 +529,7 @@ class Train (GNet):
train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d) train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g) train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
# saver = tf.train.Saver() # saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver(max_to_keep=len(self.CHECKPOINTS))
# init = tf.global_variables_initializer() # init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer() init = tf.compat.v1.global_variables_initializer()
logs = [] logs = []
@@ -587,7 +587,9 @@ class Train (GNet):
tf.compat.v1.reset_default_graph() tf.compat.v1.reset_default_graph()
# #
# let's sort the epochs we've logged thus far (if any) # let's sort the epochs we've logged thus far (if any)
# Take only the last five checkpoints https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
# #
# self.logs['epochs'] = self.logs['epochs'][-5:]
self.logs['epochs'].sort(key=lambda _item: _item['loss']) self.logs['epochs'].sort(key=lambda _item: _item['loss'])
if self.logger : if self.logger :
_log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']} _log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}