bug fixes with operations
parent 65a3e84c8f
commit 6de816fc50
data/gan.py | 22

@@ -14,6 +14,7 @@ import sys
 from data.params import SYS_ARGS
 from data.bridge import Binary
 import json
+import pickle
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 os.environ['CUDA_VISIBLE_DEVICES'] = "0"
@@ -38,7 +39,7 @@ class GNet :
         self.layers.normalize = self.normalize
 
 
-        self.NUM_GPUS = 1
+        self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
 
 
         self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
@@ -64,8 +65,8 @@ class GNet :
 
         self.get = void()
         self.get.variables = self._variable_on_cpu
 
         self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
-
+        self.logger = args['logger'] if 'logger' in args and args['logger'] else None
         self.init_logs(**args)
 
     def init_logs(self,**args):
@@ -98,7 +99,7 @@ class GNet :
 
 
     def log_meta(self,**args) :
-        object = {
+        _object = {
            'CONTEXT':self.CONTEXT,
            'ATTRIBUTES':self.ATTRIBUTES,
            'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
@@ -120,7 +121,8 @@ class GNet :
         _name = os.sep.join([self.out_dir,'meta-'+suffix])
 
         f = open(_name+'.json','w')
-        f.write(json.dumps(object))
+        f.write(json.dumps(_object))
+        return _object
     def mkdir (self,path):
         if not os.path.exists(path) :
             os.mkdir(path)
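
Note on the new keyword handling: NUM_GPUS and self.logger above (and max_epochs in the second file below) all resolve optional entries of the **args dict with the same presence-check idiom. A minimal, self-contained sketch of how those defaults behave; the args dict here is a stand-in for illustration, not an actual call into GNet:

    # Stand-in for the **args dict received by GNet.__init__ / train();
    # only 'num_gpu' is supplied, so the other keys fall back to their defaults.
    args = {'num_gpu': 2}

    NUM_GPUS   = 1  if 'num_gpu' not in args else args['num_gpu']                 # -> 2
    max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']           # -> 10
    logger     = args['logger'] if 'logger' in args and args['logger'] else None  # -> None

    # dict.get expresses the first two defaults more compactly:
    assert NUM_GPUS == args.get('num_gpu', 1)
    assert max_epochs == args.get('max_epochs', 10)
    print(NUM_GPUS, max_epochs, logger)
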
@@ -295,7 +297,7 @@ class Train (GNet):
         self.column = args['column']
         # print ([" *** ",self.BATCHSIZE_PER_GPU])
 
-        self.log_meta()
+        self.meta = self.log_meta()
     def load_meta(self, column):
         """
         This function will delegate the calls to load meta data to it's dependents
@@ -393,7 +395,7 @@ class Train (GNet):
         # saver = tf.train.Saver()
         saver = tf.compat.v1.train.Saver()
         init = tf.global_variables_initializer()
-
+        logs = []
         with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
             sess.run(init)
             sess.run(iterator_d.initializer,
@@ -415,6 +417,10 @@ class Train (GNet):
 
                 format_str = 'epoch: %d, w_distance = %f (%.1f)'
                 print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
+                # print (dir (w_distance))
+
+                logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
+
                 if epoch % self.MAX_EPOCHS == 0:
                     # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
                     suffix = self.get.suffix()
@@ -423,6 +429,10 @@ class Train (GNet):
                     saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
                     #
                     #
+            if self.logger :
+                row = {"logs":logs} #,"model":pickle.dump(sess)}
+
+                self.logger.write(row=row)
 
 class Predict(GNet):
     """
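
Note on the logging hook: the only call the training loop now makes on self.logger is write(row=...), and the row has the shape {"logs": [{"epoch": ..., "distance": ...}, ...]}. The concrete logger is built elsewhere via transport.factory.instance (see the next file); as a hedged illustration, any object exposing a compatible write method would satisfy this hunk, for example:

    import json

    class JsonLinesLogger:
        """Illustrative stand-in (not part of this codebase) exposing write(row=...)."""
        def __init__(self, path='train-logs.jsonl'):
            self.path = path
        def write(self, row=None, **kwargs):
            # Append one JSON document per call, mirroring the row built in the
            # training loop: {"logs": [{"epoch": 1, "distance": 0.42}, ...]}
            with open(self.path, 'a') as f:
                f.write(json.dumps(row) + '\n')

    # JsonLinesLogger().write(row={"logs": [{"epoch": 1, "distance": 0.42}]})
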
@@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi
 import pandas as pd
 import numpy as np
 from data import gan
-
+from transport import factory
 def train (**args) :
     """
     This function is intended to train the GAN in order to learn about the distribution of the features
@@ -21,17 +21,24 @@ def train (**args) :
     :data     data-frame to be synthesized
     :context  label of what we are synthesizing
     """
     column = args['column']
 
     column_id = args['id']
     df = args['data']
     logs = args['logs']
     real = pd.get_dummies(df[column]).astype(np.float32).values
     labels = pd.get_dummies(df[column_id]).astype(np.float32).values
-    max_epochs = 10
+    max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
     context = args['context']
-    trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id)
+    if 'store' in args :
+        args['store']['args']['doc'] = context
+        logger = factory.instance(**args['store'])
+
+    else:
+        logger = None
+
+    trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
     return trainer.apply()
 
 def generate(**args):
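
Note on calling the updated train() entry point: a hedged usage sketch. The internal layout of the optional 'store' dictionary is an assumption; this diff only shows that store['args']['doc'] is overwritten with the context and that the whole dict is handed to transport.factory.instance() to obtain the logger.

    import pandas as pd

    # Toy frame: 'gender' is the column to synthesize, 'person_id' identifies rows.
    df = pd.DataFrame({'person_id': [1, 2, 3, 4],
                       'gender':    ['M', 'F', 'F', 'M']})

    args = {
        'data':    df,
        'column':  'gender',
        'id':      'person_id',
        'logs':    'logs',        # forwarded untouched to gan.Train(logs=...);
                                  # its expected value is not shown in this diff
        'context': 'demo',
        'max_epochs': 5,          # new: optional, defaults to 10
        # 'store': {'args': {...}, ...}   # new: optional; contents depend on what
        #                                 # transport.factory.instance() expects
    }
    # train(**args)               # needs the TensorFlow/GAN stack to actually run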