bug fixes with operations

Steve Nyemba 2020-01-03 21:47:05 -06:00
parent 65a3e84c8f
commit 6de816fc50
2 changed files with 33 additions and 16 deletions

View File

@@ -14,6 +14,7 @@ import sys
from data.params import SYS_ARGS
from data.bridge import Binary
import json
import pickle
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
@@ -38,7 +39,7 @@ class GNet :
self.layers.normalize = self.normalize
self.NUM_GPUS = 1
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
@@ -64,8 +65,8 @@ class GNet :
self.get = void()
self.get.variables = self._variable_on_cpu
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
self.logger = args['logger'] if 'logger' in args and args['logger'] else None
self.init_logs(**args)
def init_logs(self,**args):
@@ -98,7 +99,7 @@ class GNet :
def log_meta(self,**args) :
object = {
_object = {
'CONTEXT':self.CONTEXT,
'ATTRIBUTES':self.ATTRIBUTES,
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
@@ -120,7 +121,8 @@ class GNet :
_name = os.sep.join([self.out_dir,'meta-'+suffix])
f = open(_name+'.json','w')
f.write(json.dumps(object))
f.write(json.dumps(_object))
return _object
def mkdir (self,path):
if not os.path.exists(path) :
os.mkdir(path)
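The hunk above renames the metadata dict to _object, serializes it to <out_dir>/meta-<suffix>.json, and now also returns it so the caller can keep a reference (see self.meta = self.log_meta() in the Train hunk below). A standalone sketch of that write-and-return pattern, assuming a free function and a context manager in place of the class method:

import json
import os

def log_meta(out_dir, suffix, _object):
    # serialize the run metadata next to the model artifacts ...
    _name = os.sep.join([out_dir, 'meta-' + suffix])
    with open(_name + '.json', 'w') as f:   # sketch only; the diff keeps a bare open()
        f.write(json.dumps(_object))
    # ... and hand the dict back so the caller can retain it
    return _object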
@@ -295,7 +297,7 @@ class Train (GNet):
self.column = args['column']
# print ([" *** ",self.BATCHSIZE_PER_GPU])
self.log_meta()
self.meta = self.log_meta()
def load_meta(self, column):
"""
This function will delegate the calls to load metadata to its dependents
@@ -393,7 +395,7 @@ class Train (GNet):
# saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()
init = tf.global_variables_initializer()
logs = []
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
sess.run(init)
sess.run(iterator_d.initializer,
@@ -415,6 +417,10 @@ class Train (GNet):
format_str = 'epoch: %d, w_distance = %f (%.1f)'
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
# print (dir (w_distance))
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
if epoch % self.MAX_EPOCHS == 0:
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
suffix = self.get.suffix()
@@ -423,6 +429,10 @@ class Train (GNet):
saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
#
#
if self.logger :
row = {"logs":logs} #,"model":pickle.dump(sess)}
self.logger.write(row=row)
class Predict(GNet):
"""

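In the training-loop hunks above, every epoch now appends an {"epoch", "distance"} record to logs, and the accumulated list is pushed through the logger once training ends. An illustrative sketch of what the writer receives; the _PrintLogger stub and the numbers are made up, only the record keys and the write(row=...) call shape come from this commit:

class _PrintLogger:
    # stand-in for the writer that transport.factory.instance would normally return
    def write(self, row):
        print(row)

logger = _PrintLogger()
logs = [
    {"epoch": 1, "distance": 0.42},
    {"epoch": 2, "distance": 0.31},
]
if logger:
    logger.write(row={"logs": logs})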
View File

@@ -11,7 +11,7 @@ This package is designed to generate synthetic data from a dataset from an origi
import pandas as pd
import numpy as np
from data import gan
from transport import factory
def train (**args) :
"""
This function is intended to train the GAN in order to learn about the distribution of the features
@@ -27,11 +27,18 @@ def train (**args) :
df = args['data']
logs = args['logs']
real = pd.get_dummies(df[column]).astype(np.float32).values
labels = pd.get_dummies(df[column_id]).astype(np.float32).values
max_epochs = 10
max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
context = args['context']
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id)
if 'store' in args :
args['store']['args']['doc'] = context
logger = factory.instance(**args['store'])
else:
logger = None
trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
return trainer.apply()
def generate(**args):
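For the train() changes above, a sketch of how a caller might pass the new optional arguments. The import path, the column/id keyword names (they are read above the hunk shown), and the store layout are assumptions; only data, logs, context, max_epochs and the nested store['args'] dict are visible in this commit:

import pandas as pd
from data import maker                      # hypothetical import path for the patched module

df = pd.DataFrame({'gender': ['M', 'F', 'F'], 'pid': ['a', 'b', 'c']})   # toy frame
maker.train(
    data=df,
    column='gender', id='pid',              # hypothetical keyword names for column / column_id
    logs='logs', context='demo',
    max_epochs=5,                           # optional since this commit; defaults to 10
    store={'type': 'mongo.MongoWriter',     # assumed spec consumed by transport.factory.instance
           'args': {'dbname': 'gan_logs'}}  # train() injects args['doc'] = context before instantiation
)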