bug fixes with operations

This commit is contained in:
Steve Nyemba 2020-01-03 21:47:05 -06:00
parent 65a3e84c8f
commit 6de816fc50
2 changed files with 33 additions and 16 deletions

View File

@ -14,6 +14,7 @@ import sys
from data.params import SYS_ARGS from data.params import SYS_ARGS
from data.bridge import Binary from data.bridge import Binary
import json import json
import pickle
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "0" os.environ['CUDA_VISIBLE_DEVICES'] = "0"
@ -38,7 +39,7 @@ class GNet :
self.layers.normalize = self.normalize self.layers.normalize = self.normalize
self.NUM_GPUS = 1 self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854 self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
@ -64,8 +65,8 @@ class GNet :
self.get = void() self.get = void()
self.get.variables = self._variable_on_cpu self.get.variables = self._variable_on_cpu
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
self.logger = args['logger'] if 'logger' in args and args['logger'] else None
self.init_logs(**args) self.init_logs(**args)
def init_logs(self,**args): def init_logs(self,**args):
@ -98,7 +99,7 @@ class GNet :
def log_meta(self,**args) : def log_meta(self,**args) :
object = { _object = {
'CONTEXT':self.CONTEXT, 'CONTEXT':self.CONTEXT,
'ATTRIBUTES':self.ATTRIBUTES, 'ATTRIBUTES':self.ATTRIBUTES,
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU, 'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
@ -120,7 +121,8 @@ class GNet :
_name = os.sep.join([self.out_dir,'meta-'+suffix]) _name = os.sep.join([self.out_dir,'meta-'+suffix])
f = open(_name+'.json','w') f = open(_name+'.json','w')
f.write(json.dumps(object)) f.write(json.dumps(_object))
return _object
def mkdir (self,path): def mkdir (self,path):
if not os.path.exists(path) : if not os.path.exists(path) :
os.mkdir(path) os.mkdir(path)
@ -295,7 +297,7 @@ class Train (GNet):
self.column = args['column'] self.column = args['column']
# print ([" *** ",self.BATCHSIZE_PER_GPU]) # print ([" *** ",self.BATCHSIZE_PER_GPU])
self.log_meta() self.meta = self.log_meta()
def load_meta(self, column): def load_meta(self, column):
""" """
This function will delegate the calls to load meta data to its dependents This function will delegate the calls to load meta data to its dependents
@ -393,7 +395,7 @@ class Train (GNet):
# saver = tf.train.Saver() # saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver() saver = tf.compat.v1.train.Saver()
init = tf.global_variables_initializer() init = tf.global_variables_initializer()
logs = []
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess: with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init) sess.run(init)
sess.run(iterator_d.initializer, sess.run(iterator_d.initializer,
@ -415,6 +417,10 @@ class Train (GNet):
format_str = 'epoch: %d, w_distance = %f (%.1f)' format_str = 'epoch: %d, w_distance = %f (%.1f)'
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
# print (dir (w_distance))
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
if epoch % self.MAX_EPOCHS == 0: if epoch % self.MAX_EPOCHS == 0:
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic'] # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
suffix = self.get.suffix() suffix = self.get.suffix()
@ -423,6 +429,10 @@ class Train (GNet):
saver.save(sess, _name, write_meta_graph=False, global_step=epoch) saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
# #
# #
if self.logger :
row = {"logs":logs} #,"model":pickle.dump(sess)}
self.logger.write(row=row)
class Predict(GNet): class Predict(GNet):
""" """

View File

@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from data import gan from data import gan
from transport import factory
def train (**args) : def train (**args) :
""" """
This function is intended to train the GAN in order to learn about the distribution of the features This function is intended to train the GAN in order to learn about the distribution of the features
@ -21,17 +21,24 @@ def train (**args) :
:data data-frame to be synthesized :data data-frame to be synthesized
:context label of what we are synthesizing :context label of what we are synthesizing
""" """
column = args['column'] column = args['column']
column_id = args['id'] column_id = args['id']
df = args['data'] df = args['data']
logs = args['logs'] logs = args['logs']
real = pd.get_dummies(df[column]).astype(np.float32).values real = pd.get_dummies(df[column]).astype(np.float32).values
labels = pd.get_dummies(df[column_id]).astype(np.float32).values
labels = pd.get_dummies(df[column_id]).astype(np.float32).values max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
max_epochs = 10 context = args['context']
context = args['context'] if 'store' in args :
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id) args['store']['args']['doc'] = context
logger = factory.instance(**args['store'])
else:
logger = None
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
return trainer.apply() return trainer.apply()
def generate(**args): def generate(**args):