bug fix with dimensionalities and removing conditions
This commit is contained in:
parent
4a25af6b13
commit
cac2dd293d
|
@ -79,7 +79,8 @@ class GNet :
|
||||||
if 'real' in args :
|
if 'real' in args :
|
||||||
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
||||||
|
|
||||||
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
# self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
||||||
|
self.BATCHSIZE_PER_GPU = 3000 if 'batch_size' not in args else int(args['batch_size'])
|
||||||
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
||||||
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
||||||
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
||||||
|
@ -410,7 +411,7 @@ class Train (GNet):
|
||||||
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
|
dataset = tf.data.Dataset.from_tensor_slices(features_placeholder)
|
||||||
# labels_placeholder = None
|
# labels_placeholder = None
|
||||||
dataset = dataset.repeat(10000)
|
dataset = dataset.repeat(10000)
|
||||||
dataset = dataset.batch(batch_size=3000)
|
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
|
||||||
dataset = dataset.prefetch(1)
|
dataset = dataset.prefetch(1)
|
||||||
# iterator = dataset.make_initializable_iterator()
|
# iterator = dataset.make_initializable_iterator()
|
||||||
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
|
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
|
||||||
|
@ -430,7 +431,8 @@ class Train (GNet):
|
||||||
(real, label) = iterator.get_next()
|
(real, label) = iterator.get_next()
|
||||||
else:
|
else:
|
||||||
real = iterator.get_next()
|
real = iterator.get_next()
|
||||||
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
|
label= None
|
||||||
|
loss, w = self.loss(scope=scope, stage=stage, real=real, label=label)
|
||||||
#tf.get_variable_scope().reuse_variables()
|
#tf.get_variable_scope().reuse_variables()
|
||||||
tf.compat.v1.get_variable_scope().reuse_variables()
|
tf.compat.v1.get_variable_scope().reuse_variables()
|
||||||
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||||
|
@ -465,6 +467,7 @@ class Train (GNet):
|
||||||
logs = []
|
logs = []
|
||||||
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||||
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||||
|
|
||||||
sess.run(init)
|
sess.run(init)
|
||||||
|
|
||||||
sess.run(iterator_d.initializer,
|
sess.run(iterator_d.initializer,
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -4,7 +4,7 @@ import sys
|
||||||
|
|
||||||
def read(fname):
|
def read(fname):
|
||||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||||
args = {"name":"data-maker","version":"1.1.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
args = {"name":"data-maker","version":"1.1.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||||
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
||||||
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
|
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
|
||||||
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'
|
args['url'] = 'https://hiplab.mc.vanderbilt.edu/git/aou/data-maker.git'
|
||||||
|
|
Loading…
Reference in New Issue