bug fixes with operations
This commit is contained in:
parent
65a3e84c8f
commit
6de816fc50
22
data/gan.py
22
data/gan.py
|
@ -14,6 +14,7 @@ import sys
|
|||
from data.params import SYS_ARGS
|
||||
from data.bridge import Binary
|
||||
import json
|
||||
import pickle
|
||||
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
||||
|
@ -38,7 +39,7 @@ class GNet :
|
|||
self.layers.normalize = self.normalize
|
||||
|
||||
|
||||
self.NUM_GPUS = 1
|
||||
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
|
||||
|
||||
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
||||
|
@ -64,8 +65,8 @@ class GNet :
|
|||
|
||||
self.get = void()
|
||||
self.get.variables = self._variable_on_cpu
|
||||
|
||||
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
||||
self.logger = args['logger'] if 'logger' in args and args['logger'] else None
|
||||
self.init_logs(**args)
|
||||
|
||||
def init_logs(self,**args):
|
||||
|
@ -98,7 +99,7 @@ class GNet :
|
|||
|
||||
|
||||
def log_meta(self,**args) :
|
||||
object = {
|
||||
_object = {
|
||||
'CONTEXT':self.CONTEXT,
|
||||
'ATTRIBUTES':self.ATTRIBUTES,
|
||||
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
|
||||
|
@ -120,7 +121,8 @@ class GNet :
|
|||
_name = os.sep.join([self.out_dir,'meta-'+suffix])
|
||||
|
||||
f = open(_name+'.json','w')
|
||||
f.write(json.dumps(object))
|
||||
f.write(json.dumps(_object))
|
||||
return _object
|
||||
def mkdir (self,path):
|
||||
if not os.path.exists(path) :
|
||||
os.mkdir(path)
|
||||
|
@ -295,7 +297,7 @@ class Train (GNet):
|
|||
self.column = args['column']
|
||||
# print ([" *** ",self.BATCHSIZE_PER_GPU])
|
||||
|
||||
self.log_meta()
|
||||
self.meta = self.log_meta()
|
||||
def load_meta(self, column):
|
||||
"""
|
||||
This function will delegate the calls to load meta data to it's dependents
|
||||
|
@ -393,7 +395,7 @@ class Train (GNet):
|
|||
# saver = tf.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
logs = []
|
||||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
|
||||
sess.run(init)
|
||||
sess.run(iterator_d.initializer,
|
||||
|
@ -415,6 +417,10 @@ class Train (GNet):
|
|||
|
||||
format_str = 'epoch: %d, w_distance = %f (%.1f)'
|
||||
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
||||
# print (dir (w_distance))
|
||||
|
||||
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
|
||||
|
||||
if epoch % self.MAX_EPOCHS == 0:
|
||||
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
||||
suffix = self.get.suffix()
|
||||
|
@ -423,6 +429,10 @@ class Train (GNet):
|
|||
saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
|
||||
#
|
||||
#
|
||||
if self.logger :
|
||||
row = {"logs":logs} #,"model":pickle.dump(sess)}
|
||||
|
||||
self.logger.write(row=row)
|
||||
|
||||
class Predict(GNet):
|
||||
"""
|
||||
|
|
|
@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
from data import gan
|
||||
|
||||
from transport import factory
|
||||
def train (**args) :
|
||||
"""
|
||||
This function is intended to train the GAN in order to learn about the distribution of the features
|
||||
|
@ -27,11 +27,18 @@ def train (**args) :
|
|||
df = args['data']
|
||||
logs = args['logs']
|
||||
real = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
|
||||
labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
max_epochs = 10
|
||||
|
||||
max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
context = args['context']
|
||||
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id)
|
||||
if 'store' in args :
|
||||
args['store']['args']['doc'] = context
|
||||
logger = factory.instance(**args['store'])
|
||||
|
||||
else:
|
||||
logger = None
|
||||
|
||||
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
|
||||
return trainer.apply()
|
||||
|
||||
def generate(**args):
|
||||
|
|
Loading…
Reference in New Issue