bug fix with compatibility (tf 2.0)
This commit is contained in:
parent
ff076013b9
commit
8d85c0b0cc
30
data/gan.py
30
data/gan.py
|
@ -193,9 +193,11 @@ class Generator (GNet):
|
||||||
fake = args['fake']
|
fake = args['fake']
|
||||||
label = args['label']
|
label = args['label']
|
||||||
y_hat_fake = self.discriminator.network(inputs=fake, label=label)
|
y_hat_fake = self.discriminator.network(inputs=fake, label=label)
|
||||||
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||||
|
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
||||||
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
|
loss = -tf.reduce_mean(y_hat_fake) + sum(all_regs)
|
||||||
tf.add_to_collection('glosses', loss)
|
#tf.add_to_collection('glosses', loss)
|
||||||
|
tf.compat.v1.add_to_collection('glosses', loss)
|
||||||
return loss, loss
|
return loss, loss
|
||||||
def load_meta(self, column):
|
def load_meta(self, column):
|
||||||
super().load_meta(column)
|
super().load_meta(column)
|
||||||
|
@ -281,10 +283,12 @@ class Discriminator(GNet):
|
||||||
grad = tf.gradients(y_hat, [x_hat])[0]
|
grad = tf.gradients(y_hat, [x_hat])[0]
|
||||||
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
|
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), 1))
|
||||||
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
|
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
|
||||||
all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
#all_regs = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
|
||||||
|
all_regs = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
|
||||||
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
|
w_distance = -tf.reduce_mean(y_hat_real) + tf.reduce_mean(y_hat_fake)
|
||||||
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
|
loss = w_distance + 10 * gradient_penalty + sum(all_regs)
|
||||||
tf.add_to_collection('dlosses', loss)
|
#tf.add_to_collection('dlosses', loss)
|
||||||
|
tf.compat.v1.add_to_collection('dlosses', loss)
|
||||||
|
|
||||||
return w_distance, loss
|
return w_distance, loss
|
||||||
class Train (GNet):
|
class Train (GNet):
|
||||||
|
@ -333,10 +337,12 @@ class Train (GNet):
|
||||||
fake = self.generator.network(inputs=z, label=label)
|
fake = self.generator.network(inputs=z, label=label)
|
||||||
if stage == 'D':
|
if stage == 'D':
|
||||||
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
|
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
|
||||||
losses = tf.get_collection('dlosses', scope)
|
#losses = tf.get_collection('dlosses', scope)
|
||||||
|
losses = tf.compat.v1.get_collection('dlosses', scope)
|
||||||
else:
|
else:
|
||||||
w, loss = self.generator.loss(fake=fake, label=label)
|
w, loss = self.generator.loss(fake=fake, label=label)
|
||||||
losses = tf.get_collection('glosses', scope)
|
#losses = tf.get_collection('glosses', scope)
|
||||||
|
losses = tf.compat.v1.get_collection('glosses', scope)
|
||||||
|
|
||||||
total_loss = tf.add_n(losses, name='total_loss')
|
total_loss = tf.add_n(losses, name='total_loss')
|
||||||
|
|
||||||
|
@ -370,8 +376,10 @@ class Train (GNet):
|
||||||
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
|
with tf.name_scope('%s_%d' % ('TOWER', i)) as scope:
|
||||||
(real, label) = iterator.get_next()
|
(real, label) = iterator.get_next()
|
||||||
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
|
loss, w = self.loss(scope=scope, stage=stage, real=self._REAL, label=self._LABEL)
|
||||||
tf.get_variable_scope().reuse_variables()
|
#tf.get_variable_scope().reuse_variables()
|
||||||
vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
tf.compat.v1.get_variable_scope().reuse_variables()
|
||||||
|
#vars_ = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||||
|
vars_ = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope=stage)
|
||||||
grads = opt.compute_gradients(loss, vars_)
|
grads = opt.compute_gradients(loss, vars_)
|
||||||
tower_grads.append(grads)
|
tower_grads.append(grads)
|
||||||
per_gpu_w.append(w)
|
per_gpu_w.append(w)
|
||||||
|
@ -394,9 +402,11 @@ class Train (GNet):
|
||||||
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
train_g, _, iterator_g, features_placeholder_g, labels_placeholder_g = self.network(stage='G', opt=opt_g)
|
||||||
# saver = tf.train.Saver()
|
# saver = tf.train.Saver()
|
||||||
saver = tf.compat.v1.train.Saver()
|
saver = tf.compat.v1.train.Saver()
|
||||||
init = tf.global_variables_initializer()
|
# init = tf.global_variables_initializer()
|
||||||
|
init = tf.compat.v1.global_variables_initializer()
|
||||||
logs = []
|
logs = []
|
||||||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
#with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||||
|
with tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||||
sess.run(init)
|
sess.run(init)
|
||||||
sess.run(iterator_d.initializer,
|
sess.run(iterator_d.initializer,
|
||||||
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})
|
feed_dict={features_placeholder_d: REAL, labels_placeholder_d: LABEL})
|
||||||
|
|
Loading…
Reference in New Issue