Merge branch 'dev' of aou/data-maker into master

This commit is contained in:
steve 2020-01-09 11:27:20 -06:00 committed by Gogs
commit 54abbeb42a
2 changed files with 21 additions and 11 deletions

View File

@ -193,9 +193,11 @@ class Generator (GNet):
fake = args['fake']
label = args['label']
y_hat_fake = self.discriminator.network(inputs=fake, label=label)
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
tf.add_to_collection('glosses', loss)
#tf.add_to_collection('glosses', loss)
tf.compat.v1.add_to_collection('glosses', loss)
return loss, loss
def load_meta(self, column):
super().load_meta(column)
@ -281,10 +283,12 @@ class Discriminator(GNet):
grad = tf.gradients(y_hat, [x_hat])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
tf.add_to_collection('dlosses', loss)
#tf.add_to_collection('dlosses', loss)
tf.compat.v1.add_to_collection('dlosses', loss)
return w_distance, loss
class Train (GNet):
@ -333,10 +337,12 @@ class Train (GNet):
fake = self.generator.network(inputs=z, label=label)
if stage == 'D':
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
losses = tf.get_collection('dlosses', scope)
#losses = tf.get_collection('dlosses', scope)
losses = tf.compat.v1.get_collection('dlosses', scope)
else:
w, loss = self.generator.loss(fake=fake, label=label)
losses = tf.get_collection('glosses', scope)
#losses = tf.get_collection('glosses', scope)
losses = tf.compat.v1.get_collection('glosses', scope)
total_loss = tf.add_n(losses, name='total_loss')
@ -370,8 +376,10 @@ class Train (GNet):
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
(real, label) = iterator.get_next()
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
tf.get_variable_scope().reuse_variables()
vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
#tf.get_variable_scope().reuse_variables()
tf.compat.v1.get_variable_scope().reuse_variables()
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
grads = opt.compute_gradients(loss, vars_)
tower_grads.append(grads)
per_gpu_w.append(w)
@ -394,9 +402,11 @@ class Train (GNet):
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
# saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()
init = tf.global_variables_initializer()
# init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer()
logs = []
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init)
sess.run(iterator_d.initializer,
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})

View File

@ -4,7 +4,7 @@ import sys
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.0.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
args = {"name":"data-maker","version":"1.0.2","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.14.0','numpy==1.16.3','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'