bug fix with compatibility (tf 2.0)

This commit is contained in:
Steve Nyemba 2020-01-09 11:20:40 -06:00
parent ff076013b9
commit 8d85c0b0cc
1 changed files with 20 additions and 10 deletions

View File

@ -193,9 +193,11 @@ class Generator (GNet):
fake = args['fake'] fake = args['fake']
label = args['label'] label = args['label']
y_hat_fake = self.discriminator.network(inputs=fake, label=label) y_hat_fake = self.discriminator.network(inputs=fake, label=label)
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs) loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
tf.add_to_collection('glosses', loss) #tf.add_to_collection('glosses', loss)
tf.compat.v1.add_to_collection('glosses', loss)
return loss, loss return loss, loss
def load_meta(self, column): def load_meta(self, column):
super().load_meta(column) super().load_meta(column)
@ -281,10 +283,12 @@ class Discriminator(GNet):
grad = tf.gradients(y_hat, [x_hat])[0] grad = tf.gradients(y_hat, [x_hat])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1)) slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2) gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) #all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake) w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
loss = w_distance + 10 * gradient_penalty + sum(all_regs) loss = w_distance + 10 * gradient_penalty + sum(all_regs)
tf.add_to_collection('dlosses', loss) #tf.add_to_collection('dlosses', loss)
tf.compat.v1.add_to_collection('dlosses', loss)
return w_distance, loss return w_distance, loss
class Train (GNet): class Train (GNet):
@ -333,10 +337,12 @@ class Train (GNet):
fake = self.generator.network(inputs=z, label=label) fake = self.generator.network(inputs=z, label=label)
if stage == 'D': if stage == 'D':
w, loss = self.discriminator.loss(real=real, fake=fake, label=label) w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
losses = tf.get_collection('dlosses', scope) #losses = tf.get_collection('dlosses', scope)
losses = tf.compat.v1.get_collection('dlosses', scope)
else: else:
w, loss = self.generator.loss(fake=fake, label=label) w, loss = self.generator.loss(fake=fake, label=label)
losses = tf.get_collection('glosses', scope) #losses = tf.get_collection('glosses', scope)
losses = tf.compat.v1.get_collection('glosses', scope)
total_loss = tf.add_n(losses, name='total_loss') total_loss = tf.add_n(losses, name='total_loss')
@ -370,8 +376,10 @@ class Train (GNet):
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope: with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
(real, label) = iterator.get_next() (real, label) = iterator.get_next()
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL) loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
tf.get_variable_scope().reuse_variables() #tf.get_variable_scope().reuse_variables()
vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage) tf.compat.v1.get_variable_scope().reuse_variables()
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
grads = opt.compute_gradients(loss, vars_) grads = opt.compute_gradients(loss, vars_)
tower_grads.append(grads) tower_grads.append(grads)
per_gpu_w.append(w) per_gpu_w.append(w)
@ -394,9 +402,11 @@ class Train (GNet):
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g) train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
# saver = tf.train.Saver() # saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver()
init = tf.global_variables_initializer() # init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer()
logs = [] logs = []
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init) sess.run(init)
sess.run(iterator_d.initializer, sess.run(iterator_d.initializer,
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL}) feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})