Merge branch 'dev' of aou/data-maker into master
This commit is contained in:
commit
54abbeb42a
30
data/gan.py
30
data/gan.py
|
@ -193,9 +193,11 @@ class Generator (GNet):
|
|||
fake = args['fake']
|
||||
label = args['label']
|
||||
y_hat_fake = self.discriminator.network(inputs=fake, label=label)
|
||||
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
||||
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
|
||||
tf.add_to_collection('glosses', loss)
|
||||
#tf.add_to_collection('glosses', loss)
|
||||
tf.compat.v1.add_to_collection('glosses', loss)
|
||||
return loss, loss
|
||||
def load_meta(self, column):
|
||||
super().load_meta(column)
|
||||
|
@ -281,10 +283,12 @@ class Discriminator(GNet):
|
|||
grad = tf.gradients(y_hat, [x_hat])[0]
|
||||
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
|
||||
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
|
||||
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
||||
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
|
||||
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
|
||||
tf.add_to_collection('dlosses', loss)
|
||||
#tf.add_to_collection('dlosses', loss)
|
||||
tf.compat.v1.add_to_collection('dlosses', loss)
|
||||
|
||||
return w_distance, loss
|
||||
class Train (GNet):
|
||||
|
@ -333,10 +337,12 @@ class Train (GNet):
|
|||
fake = self.generator.network(inputs=z, label=label)
|
||||
if stage == 'D':
|
||||
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
|
||||
losses = tf.get_collection('dlosses', scope)
|
||||
#losses = tf.get_collection('dlosses', scope)
|
||||
losses = tf.compat.v1.get_collection('dlosses', scope)
|
||||
else:
|
||||
w, loss = self.generator.loss(fake=fake, label=label)
|
||||
losses = tf.get_collection('glosses', scope)
|
||||
#losses = tf.get_collection('glosses', scope)
|
||||
losses = tf.compat.v1.get_collection('glosses', scope)
|
||||
|
||||
total_loss = tf.add_n(losses, name='total_loss')
|
||||
|
||||
|
@ -370,8 +376,10 @@ class Train (GNet):
|
|||
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
|
||||
(real, label) = iterator.get_next()
|
||||
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||
#tf.get_variable_scope().reuse_variables()
|
||||
tf.compat.v1.get_variable_scope().reuse_variables()
|
||||
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||
vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||
grads = opt.compute_gradients(loss, vars_)
|
||||
tower_grads.append(grads)
|
||||
per_gpu_w.append(w)
|
||||
|
@ -394,9 +402,11 @@ class Train (GNet):
|
|||
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
||||
# saver = tf.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
init = tf.global_variables_initializer()
|
||||
# init = tf.global_variables_initializer()
|
||||
init = tf.compat.v1.global_variables_initializer()
|
||||
logs = []
|
||||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
sess.run(init)
|
||||
sess.run(iterator_d.initializer,
|
||||
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})
|
||||
|
|
2
setup.py
2
setup.py
|
@ -4,7 +4,7 @@ import sys
|
|||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
args = {"name":"data-maker","version":"1.0.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
args = {"name":"data-maker","version":"1.0.2","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
||||
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.14.0','numpy==1.16.3','pandas','pandas-gbq','pymongo']
|
||||
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'
|
||||
|
|
Loading…
Reference in New Issue