diff --git a/data/gan.py b/data/gan.py index 3391b78..7bd17ee 100644 --- a/data/gan.py +++ b/data/gan.py @@ -193,9 +193,11 @@ class Generator (GNet): fake = args['fake'] label = args['label'] y_hat_fake = self.discriminator.network(inputs=fake, label=label) - all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + #all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES) loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs) - tf.add_to_collection('glosses', loss) + #tf.add_to_collection('glosses', loss) + tf.compat.v1.add_to_collection('glosses', loss) return loss, loss def load_meta(self, column): super().load_meta(column) @@ -281,10 +283,12 @@ class Discriminator(GNet): grad = tf.gradients(y_hat, [x_hat])[0] slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1)) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) - all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + #all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES) w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake) loss = w_distance + 10 * gradient_penalty + sum(all_regs) - tf.add_to_collection('dlosses', loss) + #tf.add_to_collection('dlosses', loss) + tf.compat.v1.add_to_collection('dlosses', loss) return w_distance, loss class Train (GNet): @@ -333,10 +337,12 @@ class Train (GNet): fake = self.generator.network(inputs=z, label=label) if stage == 'D': w, loss = self.discriminator.loss(real=real, fake=fake, label=label) - losses = tf.get_collection('dlosses', scope) + #losses = tf.get_collection('dlosses', scope) + losses = tf.compat.v1.get_collection('dlosses', scope) else: w, loss = self.generator.loss(fake=fake, label=label) - losses = tf.get_collection('glosses', scope) + #losses = tf.get_collection('glosses', scope) + losses = tf.compat.v1.get_collection('glosses', scope) total_loss = tf.add_n(losses, name='total_loss') @@ -370,8 +376,10 @@ class Train (GNet): with tf.name_scope('%s_%d' % ('TOWER', i)) as scope: (real, label) = iterator.get_next() loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL) - tf.get_variable_scope().reuse_variables() - vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) + #tf.get_variable_scope().reuse_variables() + tf.compat.v1.get_variable_scope().reuse_variables() + #vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) + vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage) grads = opt.compute_gradients(loss, vars_) tower_grads.append(grads) per_gpu_w.append(w) @@ -394,9 +402,11 @@ class Train (GNet): train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g) # saver = tf.train.Saver() saver = tf.compat.v1.train.Saver() - init = tf.global_variables_initializer() + # init = tf.global_variables_initializer() + init = tf.compat.v1.global_variables_initializer() logs = [] - with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: + #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: + with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: sess.run(init) sess.run(iterator_d.initializer, feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL}) diff --git a/setup.py b/setup.py index aa45602..762ced3 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ import sys def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() -args = {"name":"data-maker","version":"1.0.1","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", +args = {"name":"data-maker","version":"1.0.2","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", "packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]} args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.14.0','numpy==1.16.3','pandas','pandas-gbq','pymongo'] args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'