bug fixes
This commit is contained in:
parent
f85d5fd870
commit
dcc55eb1fb
25
data/gan.py
25
data/gan.py
|
@ -1,8 +1,23 @@
|
|||
"""
|
||||
usage :
|
||||
optional :
|
||||
--num_gpu number of GPUs to use (defaults to 1)
|
||||
--epoch steps per epoch (defaults to 256)
|
||||
This code was originally written by Ziqi Zhang <ziqi.zhang@vanderbilt.edu> in order to generate synthetic data.
|
||||
The code is an implementation of a Generative Adversarial Network that uses the Wasserstein Distance (WGAN).
|
||||
It is intended to be used in 2 modes (embedded in code or using CLI)
|
||||
|
||||
USAGE :
|
||||
|
||||
The following parameters should be provided in a configuration file (JSON format)
|
||||
python data/maker --config <path-to-config-file.json>
|
||||
|
||||
CONFIGURATION FILE STRUCTURE :
|
||||
|
||||
context what it is you are loading (stroke, hypertension, ...)
|
||||
data path of the file to be loaded
|
||||
logs folder to store training model and meta data about learning
|
||||
max_epochs number of iterations in learning
|
||||
num_gpu number of gpus to be used (will still run if the GPUs are not available)
|
||||
|
||||
EMBEDDED IN CODE :
|
||||
|
||||
"""
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.layers import l2_regularizer
|
||||
|
@ -426,7 +441,7 @@ class Train (GNet):
|
|||
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
||||
# print (dir (w_distance))
|
||||
|
||||
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
|
||||
logs.append({"epoch":epoch,"distance":-w_sum })
|
||||
|
||||
if epoch % self.MAX_EPOCHS == 0:
|
||||
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
||||
|
|
|
@ -24,21 +24,25 @@ def train (**args) :
|
|||
column = args['column']
|
||||
|
||||
column_id = args['id']
|
||||
df = args['data']
|
||||
logs = args['logs']
|
||||
real = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
|
||||
# logs = args['logs']
|
||||
# real = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
# labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['real'] = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
# num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
# max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
context = args['context']
|
||||
|
||||
if 'store' in args :
|
||||
args['store']['args']['doc'] = context
|
||||
logger = factory.instance(**args['store'])
|
||||
args['logger'] = logger
|
||||
|
||||
else:
|
||||
logger = None
|
||||
|
||||
trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
|
||||
trainer = gan.Train(**args)
|
||||
# trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
|
||||
return trainer.apply()
|
||||
|
||||
def generate(**args):
|
||||
|
@ -51,14 +55,14 @@ def generate(**args):
|
|||
:id column identifying an entity
|
||||
:logs location on disk where the learnt knowledge of the dataset is
|
||||
"""
|
||||
df = args['data']
|
||||
|
||||
# df = args['data']
|
||||
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
|
||||
column = args['column']
|
||||
column_id = args['id']
|
||||
logs = args['logs']
|
||||
context = args['context']
|
||||
num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
# logs = args['logs']
|
||||
# context = args['context']
|
||||
# num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
# max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
|
||||
#
|
||||
#@TODO:
|
||||
|
@ -69,8 +73,11 @@ def generate(**args):
|
|||
values = df[column].unique().tolist()
|
||||
values.sort()
|
||||
|
||||
labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs)
|
||||
# labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['values'] = values
|
||||
# handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs)
|
||||
handler = gan.Predict (**args)
|
||||
handler.load_meta(column)
|
||||
r = handler.apply()
|
||||
_df = df.copy()
|
||||
|
|
|
@ -1,10 +1,27 @@
|
|||
import pandas as pd
|
||||
import data.maker
|
||||
|
||||
df = pd.read_csv('sample.csv')
|
||||
column = 'gender'
|
||||
id = 'id'
|
||||
context = 'demo'
|
||||
store = {"type":"mongo.MongoWriter","args":{"host":"localhost:27017","dbname":"GAN"}}
|
||||
max_epochs = 11
|
||||
data.maker.train(store=store,max_epochs=max_epochs,context=context,data=df,column=column,id=id,logs='foo')
|
||||
from data.params import SYS_ARGS
|
||||
import json
|
||||
from scipy.stats import wasserstein_distance as wd
|
||||
import risk
|
||||
import numpy as np
|
||||
if 'config' in SYS_ARGS :
|
||||
ARGS = json.loads(open(SYS_ARGS['config']).read())
|
||||
if 'generate' not in SYS_ARGS :
|
||||
data.maker.train(**ARGS)
|
||||
else:
|
||||
#
|
||||
#
|
||||
_df = data.maker.generate(**ARGS)
|
||||
odf = pd.read_csv (ARGS['data'])
|
||||
odf.columns = [name.lower() for name in odf.columns]
|
||||
column = [ARGS['column'] ] #+ ARGS['id']
|
||||
print (column)
|
||||
print (_df[column].risk.evaluate())
|
||||
print (odf[column].risk.evaluate())
|
||||
_x = pd.get_dummies(_df[column]).values
|
||||
y = pd.get_dummies(odf[column]).values
|
||||
N = _df.shape[0]
|
||||
print (np.mean([ wd(_x[i],y[i])for i in range(0,N)]))
|
||||
# column = SYS_ARGS['column']
|
||||
# odf = open(SYS_ARGS['data'])
|
Loading…
Reference in New Issue