fixes with the framework - only supports single feature

This commit is contained in:
Steve Nyemba 2019-12-31 23:27:53 -06:00
parent 7f3748121c
commit 685c567661
5 changed files with 125 additions and 30 deletions

View File

@ -15,22 +15,29 @@ After installing the easiest way to get started is as follows (using pandas). Th
1. Train the GAN on the original/raw dataset 1. Train the GAN on the original/raw dataset
import pandas as pd import pandas as pd
import data.maker import data.maker
df = pd.read_csv('myfile.csv') df = pd.read_csv('sample.csv')
cols= ['f1','f2','f2'] column = 'gender'
data.maker.train(data=df,cols=cols,logs='logs') id = 'id'
context = 'demo'
data.maker.train(context=context,data=df,column=column,id=id,logs='logs')
The trainer will store the data on disk (for now) in a structured folder that will hold training models that will be used to generate the synthetic data.
2. Generate a candidate dataset from the learnt features 2. Generate a candidate dataset from the learnt features
import pandas as pd import pandas as pd
import data.maker import data.maker
df = data.maker.generate(logs='logs')
df.head()
df = pd.read_csv('sample.csv')
id = 'id'
column = 'gender'
context = 'demo'
data.maker.generate(context=context,data=df,id=id,column=column,logs='logs')
## Limitations ## Limitations

1
data/__init__.py Normal file
View File

@ -0,0 +1 @@
import data.params as params

View File

@ -11,8 +11,8 @@ import pandas as pd
import time import time
import os import os
import sys import sys
from params import SYS_ARGS from data.params import SYS_ARGS
from bridge import Binary from data.bridge import Binary
import json import json
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@ -37,8 +37,6 @@ class GNet :
self.layers = void() self.layers = void()
self.layers.normalize = self.normalize self.layers.normalize = self.normalize
self.get = void()
self.get.variables = self._variable_on_cpu
self.NUM_GPUS = 1 self.NUM_GPUS = 1
@ -63,7 +61,11 @@ class GNet :
self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None} self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None}
self._REAL = args['real'] if 'real' in args else None self._REAL = args['real'] if 'real' in args else None
self._LABEL = args['label'] if 'label' in args else None self._LABEL = args['label'] if 'label' in args else None
self.get = void()
self.get.variables = self._variable_on_cpu
self.get.suffix = lambda : "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
self.init_logs(**args) self.init_logs(**args)
def init_logs(self,**args): def init_logs(self,**args):
@ -83,7 +85,9 @@ class GNet :
This function is designed to accomodate the uses of the sub-classes outside of a strict dependency model. This function is designed to accomodate the uses of the sub-classes outside of a strict dependency model.
Because prediction and training can happen independently Because prediction and training can happen independently
""" """
_name = os.sep.join([self.out_dir,'meta-'+column+'.json']) # suffix = "-".join(column) if isinstance(column,list)else column
suffix = self.get.suffix()
_name = os.sep.join([self.out_dir,'meta-'+suffix+'.json'])
if os.path.exists(_name) : if os.path.exists(_name) :
attr = json.loads((open(_name)).read()) attr = json.loads((open(_name)).read())
for key in attr : for key in attr :
@ -111,7 +115,10 @@ class GNet :
key = args['key'] key = args['key']
value= args['value'] value= args['value']
object[key] = value object[key] = value
_name = os.sep.join([self.out_dir,'meta-'+SYS_ARGS['column']]) # suffix = "-".join(self.column) if isinstance(self.column,list) else self.column
suffix = self.get.suffix()
_name = os.sep.join([self.out_dir,'meta-'+suffix])
f = open(_name+'.json','w') f = open(_name+'.json','w')
f.write(json.dumps(object)) f.write(json.dumps(object))
def mkdir (self,path): def mkdir (self,path):
@ -285,7 +292,9 @@ class Train (GNet):
self.discriminator = Discriminator(**args) self.discriminator = Discriminator(**args)
self._REAL = args['real'] self._REAL = args['real']
self._LABEL= args['label'] self._LABEL= args['label']
self.column = args['column']
# print ([" *** ",self.BATCHSIZE_PER_GPU]) # print ([" *** ",self.BATCHSIZE_PER_GPU])
self.log_meta() self.log_meta()
def load_meta(self, column): def load_meta(self, column):
""" """
@ -407,8 +416,9 @@ class Train (GNet):
format_str = 'epoch: %d, w_distance = %f (%.1f)' format_str = 'epoch: %d, w_distance = %f (%.1f)'
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration)) print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
if epoch % self.MAX_EPOCHS == 0: if epoch % self.MAX_EPOCHS == 0:
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
_name = os.sep.join([self.train_dir,self.ATTRIBUTES['synthetic']]) suffix = self.get.suffix()
_name = os.sep.join([self.train_dir,suffix])
# saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch) # saver.save(sess, self.train_dir, write_meta_graph=False, global_step=epoch)
saver.save(sess, _name, write_meta_graph=False, global_step=epoch) saver.save(sess, _name, write_meta_graph=False, global_step=epoch)
# #
@ -420,14 +430,16 @@ class Predict(GNet):
""" """
def __init__(self,**args): def __init__(self,**args):
GNet.__init__(self,**args) GNet.__init__(self,**args)
self.generator = Generator(**args) self.generator = Generator(**args)
self.values = values self.values = args['values']
def load_meta(self, column): def load_meta(self, column):
super().load_meta(column) super().load_meta(column)
self.generator.load_meta(column) self.generator.load_meta(column)
def apply(self,**args): def apply(self,**args):
# print (self.train_dir) # print (self.train_dir)
model_dir = os.sep.join([self.train_dir,self.ATTRIBUTES['synthetic']+'-'+str(self.MAX_EPOCHS)]) # suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
suffix = self.get.suffix()
model_dir = os.sep.join([self.train_dir,suffix+'-'+str(self.MAX_EPOCHS)])
demo = self._LABEL #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo'] demo = self._LABEL #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo']
tf.compat.v1.reset_default_graph() tf.compat.v1.reset_default_graph()
z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM]) z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM])
@ -450,19 +462,24 @@ class Predict(GNet):
# if we are dealing with numeric values only we can perform a simple marginal sum against the indexes # if we are dealing with numeric values only we can perform a simple marginal sum against the indexes
# #
df = ( pd.DataFrame(np.round(f).astype(np.int32),columns=values)) df = ( pd.DataFrame(np.round(f).astype(np.int32)))
# i = df.T.index.astype(np.int32) #-- These are numeric pseudonyms # i = df.T.index.astype(np.int32) #-- These are numeric pseudonyms
# df = (i * df).sum(axis=1) # df = (i * df).sum(axis=1)
# #
# In case we are dealing with actual values like diagnosis codes we can perform # In case we are dealing with actual values like diagnosis codes we can perform
# #
r = np.zeros((self.ROW_COUNT,1)) columns = self.ATTRIBUTES['synthetic'] if isinstance(self.ATTRIBUTES['synthetic'],list)else [self.ATTRIBUTES['synthetic']]
r = np.zeros((self.ROW_COUNT,len(columns)))
for col in df : for col in df :
i = np.where(df[col])[0] i = np.where(df[col])[0]
r[i] = col r[i] = col
df = pd.DataFrame(r,columns=[self.ATTRIBUTES['synthetic']])
return df.to_dict(orient='list') df = pd.DataFrame(r,columns=columns)
df[df.columns] = (df.apply(lambda value: self.values[ int(value)],axis=1))
return df.to_dict(orient='lists')
# return df.to_dict(orient='list')
# count = str(len(os.listdir(self.out_dir))) # count = str(len(os.listdir(self.out_dir)))
# _name = os.sep.join([self.out_dir,self.CONTEXT+'-'+count+'.csv']) # _name = os.sep.join([self.out_dir,self.CONTEXT+'-'+count+'.csv'])
# df.to_csv(_name,index=False) # df.to_csv(_name,index=False)
@ -476,7 +493,7 @@ class Predict(GNet):
# idx2 = (demo[:, n] == 1) # idx2 = (demo[:, n] == 1)
# idx = [idx1[j] and idx2[j] for j in range(len(idx1))] # idx = [idx1[j] and idx2[j] for j in range(len(idx1))]
# num = np.sum(idx) # num = np.sum(idx)
# print ("_____________________") # print ("___________________list__")
# print (idx1) # print (idx1)
# print (idx2) # print (idx2)
# print (idx) # print (idx)
@ -531,7 +548,8 @@ if __name__ == '__main__' :
elif 'generate' in SYS_ARGS: elif 'generate' in SYS_ARGS:
values = df[column].unique().tolist() values = df[column].unique().tolist()
values.sort() values.sort()
p = Predict(context=context,label=LABEL,values=values)
p = Predict(context=context,label=LABEL,values=values,column=column)
p.load_meta(column) p.load_meta(column)
r = p.apply() r = p.apply()
print (df) print (df)
@ -539,6 +557,7 @@ if __name__ == '__main__' :
df[column] = r[column] df[column] = r[column]
print (df) print (df)
else: else:
print (SYS_ARGS.keys()) print (SYS_ARGS.keys())
print (__doc__) print (__doc__)

68
data/maker/__init__.py Normal file
View File

@ -0,0 +1,68 @@
"""
(c) 2019 Data Maker, hiplab.mc.vanderbilt.edu
version 1.0.0
This package serves as a proxy to the overall usage of the framework.
This package is designed to generate synthetic data from an original dataset using deep learning techniques
@TODO:
- Make configurable GPU, EPOCHS
"""
import pandas as pd
import numpy as np
from data import gan
def train (**args) :
    """
    This function is intended to train the GAN in order to learn about the distribution of the features

    :column     column that needs to be synthesized (discrete)
    :logs       location on disk, forwarded to the trainer as the keyword argument 'logs'
    :id         name of the column identifying an entity in the dataset
    :data       data-frame to be synthesized
    :context    label of what we are synthesizing
    :max_epochs (optional) number of epochs handed to the trainer, defaults to 10
    @return     the value returned by gan.Train.apply
    """
    column      = args['column']
    column_id   = args['id']
    df          = args['data']
    logs        = args['logs']
    context     = args['context']
    #
    # The number of epochs used to be fixed at 10; it stays the default but callers may now override it
    max_epochs  = args.get('max_epochs', 10)
    #
    # The column to synthesize and the identifier are both one-hot encoded into float32 matrices
    real    = pd.get_dummies(df[column]).astype(np.float32).values
    labels  = pd.get_dummies(df[column_id]).astype(np.float32).values
    #
    # logs was previously read and then dropped; it is now passed along with the other keyword arguments.
    # NOTE(review): how the trainer uses 'logs' is decided in data.gan (presumably the output folder) - confirm
    trainer = gan.Train(context=context,max_epochs=max_epochs,real=real,label=labels,column=column,column_id=column_id,logs=logs)
    return trainer.apply()
def generate(**args):
    """
    This function will generate a synthetic dataset on the basis of a model that has been learnt for the dataset
    @return pandas.DataFrame, a copy of :data in which :column holds the generated values

    :data       data-frame to be synthesized
    :column     column that needs to be synthesized (discrete)
    :id         column identifying an entity
    :logs       location on disk where the learnt knowledge of the dataset is, forwarded to the predictor as 'logs'
    :context    label of what we are synthesizing (presumably the one given to train - confirm)
    """
    df          = args['data']
    column      = args['column']
    column_id   = args['id']
    logs        = args['logs']
    context     = args['context']
    #
    #@TODO:
    #   If the identifier is not present, we should find a way to determine or make one
    #
    # Distinct values of the column in sorted order, handed to the predictor as 'values'
    values  = sorted(df[column].unique().tolist())
    labels  = pd.get_dummies(df[column_id]).astype(np.float32).values
    #
    # logs was previously read and then dropped; it is now passed along with the other keyword arguments.
    # NOTE(review): how the predictor uses 'logs' is decided in data.gan - confirm
    handler = gan.Predict(context=context,label=labels,values=values,column=column,logs=logs)
    handler.load_meta(column)
    r = handler.apply()
    #
    # The result is written into a copy so the caller's data-frame is left untouched
    _df = df.copy()
    _df[column] = r[column]
    return _df

View File

@ -5,7 +5,7 @@ import sys
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
args = {"name":"data-maker","version":"1.0.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT", args = {"name":"data-maker","version":"1.0.0","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
"packages":["data-maker"],"keywords":["healthcare","edi","x12","data","transport","protocol"]} "packages":["data-maker"],"keywords":["healthcare","data","transport","protocol"]}
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','numpy','pandas','pandas-gbq','pymongo'] args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','numpy','pandas','pandas-gbq','pymongo']
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/gan.git' args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/gan.git'
if sys.version_info[0] == 2 : if sys.version_info[0] == 2 :