bug fix: checkpoints
This commit is contained in:
parent
4be340ec08
commit
209a7b8ee5
|
@ -103,7 +103,7 @@ class GNet :
|
|||
CHECKPOINT_SKIPS = 10
|
||||
if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
||||
CHECKPOINT_SKIPS = 2
|
||||
self.CHECKPOINTS = np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
||||
self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
||||
|
||||
|
||||
|
||||
|
@ -529,7 +529,7 @@ class Train (GNet):
|
|||
train_d, w_distance, iterator_d, features_placeholder_d, labels_placeholder_d = self.network(stage='D', opt=opt_d)
|
||||
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
||||
# saver = tf.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver(max_to_keep=len(self.CHECKPOINTS))
|
||||
# init = tf.global_variables_initializer()
|
||||
init = tf.compat.v1.global_variables_initializer()
|
||||
logs = []
|
||||
|
@ -564,7 +564,7 @@ class Train (GNet):
|
|||
|
||||
# if epoch % self.MAX_EPOCHS == 0:
|
||||
# if epoch in [5,10,20,50,75, self.MAX_EPOCHS] :
|
||||
if epoch in self.CHECKPOINTS :
|
||||
if epoch in self.CHECKPOINTS :
|
||||
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
||||
suffix = self.CONTEXT #self.get.suffix()
|
||||
_name = os.sep.join([self.train_dir,str(epoch),suffix])
|
||||
|
@ -587,7 +587,9 @@ class Train (GNet):
|
|||
tf.compat.v1.reset_default_graph()
|
||||
#
|
||||
# let's sort the epochs we've logged thus far (if any)
|
||||
# Take only the last five checkpoints; see https://stackoverflow.com/questions/41018454/tensorflow-checkpoint-models-getting-deleted
|
||||
#
|
||||
# self.logs['epochs'] = self.logs['epochs'][-5:]
|
||||
self.logs['epochs'].sort(key=lambda _item: _item['loss'])
|
||||
if self.logger :
|
||||
_log = {'module':'gan-train','action':'epochs','input':self.logs['epochs']}
|
||||
|
|
Loading…
Reference in New Issue