fixes with new features

This commit is contained in:
Steve Nyemba 2022-11-09 14:28:34 -06:00
parent ce594634e8
commit d469a4904f
2 changed files with 3 additions and 1 deletions

View File

@ -101,6 +101,8 @@ class GNet :
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10) CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10)
CHECKPOINT_SKIPS = 1 if CHECKPOINT_SKIPS < 1 else CHECKPOINT_SKIPS
# if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS : # if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
# CHECKPOINT_SKIPS = 2 # CHECKPOINT_SKIPS = 2
# self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist() # self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()

View File

@ -4,7 +4,7 @@ import sys
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.6.2", args = {"name":"data-maker","version":"1.6.3",
"author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT", "author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow'] args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow']