bug fix with dimensionalities and removing conditions

This commit is contained in:
Steve Nyemba 2020-02-18 16:56:24 -06:00
parent 4a25af6b13
commit cac2dd293d
2 changed files with 7 additions and 4 deletions

View File

@ -79,7 +79,8 @@ class GNet :
if 'real' in args : if 'real' in args :
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM] self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256 # self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000) self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs']) self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
@ -410,7 +411,7 @@ class Train (GNet):
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
# labels_placeholder = None # labels_placeholder = None
dataset = dataset.repeat(10000) dataset = dataset.repeat(10000)
dataset = dataset.batch(batch_size=3000) dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
dataset = dataset.prefetch(1) dataset = dataset.prefetch(1)
# iterator = dataset.make_initializable_iterator() # iterator = dataset.make_initializable_iterator()
iterator = tf.compat.v1.data.make_initializable_iterator(dataset) iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
@ -430,7 +431,8 @@ class Train (GNet):
(real, label) = iterator.get_next() (real, label) = iterator.get_next()
else: else:
real = iterator.get_next() real = iterator.get_next()
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL) label= None
loss, w = self.loss(scope=scope, stage=stage, real=real, label=label)
#tf.get_variable_scope().reuse_variables() #tf.get_variable_scope().reuse_variables()
tf.compat.v1.get_variable_scope().reuse_variables() tf.compat.v1.get_variable_scope().reuse_variables()
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
@ -465,6 +467,7 @@ class Train (GNet):
logs = [] logs = []
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init) sess.run(init)
sess.run(iterator_d.initializer, sess.run(iterator_d.initializer,

View File

@ -4,7 +4,7 @@ import sys
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.1.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", args = {"name":"data-maker","version":"1.1.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo'] args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git' args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'