bug fix with dimensionalities and removing conditions
This commit is contained in:
parent
4a25af6b13
commit
cac2dd293d
|
@ -79,7 +79,8 @@ class GNet :
|
|||
if 'real' in args :
|
||||
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
||||
|
||||
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
||||
# self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
||||
self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
|
||||
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
||||
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
||||
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
||||
|
@ -410,7 +411,7 @@ class Train (GNet):
|
|||
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
|
||||
# labels_placeholder = None
|
||||
dataset = dataset.repeat(10000)
|
||||
dataset = dataset.batch(batch_size=3000)
|
||||
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
|
||||
dataset = dataset.prefetch(1)
|
||||
# iterator = dataset.make_initializable_iterator()
|
||||
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
|
||||
|
@ -430,7 +431,8 @@ class Train (GNet):
|
|||
(real, label) = iterator.get_next()
|
||||
else:
|
||||
real = iterator.get_next()
|
||||
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
|
||||
label= None
|
||||
loss, w = self.loss(scope=scope, stage=stage, real=real, label=label)
|
||||
#tf.get_variable_scope().reuse_variables()
|
||||
tf.compat.v1.get_variable_scope().reuse_variables()
|
||||
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||
|
@ -465,6 +467,7 @@ class Train (GNet):
|
|||
logs = []
|
||||
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
|
||||
sess.run(init)
|
||||
|
||||
sess.run(iterator_d.initializer,
|
||||
|
|
2
setup.py
2
setup.py
|
@ -4,7 +4,7 @@ import sys
|
|||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
args = {"name":"data-maker","version":"1.1.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
args = {"name":"data-maker","version":"1.1.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
||||
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
|
||||
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'
|
||||
|
|
Loading…
Reference in New Issue