not sure about the changes (oops)
This commit is contained in:
parent
63a7f1a968
commit
31ca5886f0
99
data/gan.py
99
data/gan.py
|
@ -43,6 +43,10 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
|||
# Bare attribute container: GNet.__init__ instantiates it as self.layers and attaches attributes to that instance
class void :
|
||||
pass
|
||||
class GNet :
|
||||
def log(self,**args):
|
||||
# Merge the keyword arguments into self.logs and rebind self.logs to the merged dict.
# On a key collision the value already stored in self.logs wins, because self.logs is
# expanded as keyword arguments on top of args.
# NOTE(review): if newer values are meant to overwrite older ones the operands are reversed -- confirm intent
self.logs = dict(args,**self.logs)
|
||||
|
||||
|
||||
"""
|
||||
This is the base class for generative network functions; the details are implemented in the subclasses.
|
||||
An instance of this class is accessed as follows
|
||||
|
@ -52,7 +56,7 @@ class GNet :
|
|||
def __init__(self,**args):
|
||||
self.layers = void()
|
||||
self.layers.normalize = self.normalize
|
||||
|
||||
self.logs = {}
|
||||
|
||||
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
|
||||
|
@ -95,6 +99,15 @@ class GNet :
|
|||
|
||||
self.train_dir = os.sep.join([self.log_dir,'train',self.CONTEXT])
|
||||
self.out_dir = os.sep.join([self.log_dir,'output',self.CONTEXT])
|
||||
if self.logger :
|
||||
#
|
||||
# We will clear the logs from the data-store
|
||||
#
|
||||
column = self.ATTRIBUTES['synthetic']
|
||||
db = self.logger.db
|
||||
if db[column].count() > 0 :
|
||||
db.backup.insert({'name':column,'logs':list(db[column].find()) })
|
||||
db[column].drop()
|
||||
|
||||
def load_meta(self,column):
|
||||
"""
|
||||
|
@ -114,7 +127,9 @@ class GNet :
|
|||
|
||||
|
||||
def log_meta(self,**args) :
|
||||
|
||||
_object = {
|
||||
'_id':'meta',
|
||||
'CONTEXT':self.CONTEXT,
|
||||
'ATTRIBUTES':self.ATTRIBUTES,
|
||||
'BATCHSIZE_PER_GPU':self.BATCHSIZE_PER_GPU,
|
||||
|
@ -314,6 +329,11 @@ class Train (GNet):
|
|||
# print ([" *** ",self.BATCHSIZE_PER_GPU])
|
||||
|
||||
self.meta = self.log_meta()
|
||||
if(self.logger):
|
||||
|
||||
self.logger.write( row=self.meta )
|
||||
|
||||
self.log (real_shape=list(self._REAL.shape),label_shape = list(self._LABEL.shape),meta_data=self.meta)
|
||||
def load_meta(self, column):
|
||||
"""
|
||||
This function will delegate the calls to load meta data to its dependents
|
||||
|
@ -350,11 +370,14 @@ class Train (GNet):
|
|||
if stage == 'D':
|
||||
w, loss = self.discriminator.loss(real=real, fake=fake, label=label)
|
||||
#losses = tf.get_collection('dlosses', scope)
|
||||
flag = 'dlosses'
|
||||
losses = tf.compat.v1.get_collection('dlosses', scope)
|
||||
else:
|
||||
w, loss = self.generator.loss(fake=fake, label=label)
|
||||
#losses = tf.get_collection('glosses', scope)
|
||||
flag = 'glosses'
|
||||
losses = tf.compat.v1.get_collection('glosses', scope)
|
||||
# losses = tf.compat.v1.get_collection(flag, scope)
|
||||
|
||||
total_loss = tf.add_n(losses, name='total_loss')
|
||||
|
||||
|
@ -369,7 +392,8 @@ class Train (GNet):
|
|||
dataset = dataset.repeat(10000)
|
||||
dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU)
|
||||
dataset = dataset.prefetch(1)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
# iterator = dataset.make_initializable_iterator()
|
||||
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
|
||||
# next_element = iterator.get_next()
|
||||
# init_op = iterator.initializer
|
||||
return iterator, features_placeholder, labels_placeholder
|
||||
|
@ -405,7 +429,10 @@ class Train (GNet):
|
|||
def apply(self,**args):
|
||||
# max_epochs = args['max_epochs'] if 'max_epochs' in args else 10
|
||||
REAL = self._REAL
|
||||
LABEL= self._LABEL
|
||||
LABEL= self._LABEL
|
||||
if (self.logger):
|
||||
pass
|
||||
|
||||
with tf.device('/cpu:0'):
|
||||
opt_d = tf.compat.v1.train.AdamOptimizer(1e-4)
|
||||
opt_g = tf.compat.v1.train.AdamOptimizer(1e-4)
|
||||
|
@ -441,7 +468,7 @@ class Train (GNet):
|
|||
print(format_str % (epoch, -w_sum/(self.STEPS_PER_EPOCH*2), duration))
|
||||
# print (dir (w_distance))
|
||||
|
||||
logs.append({"epoch":epoch,"distance":-w_sum })
|
||||
logs.append({"epoch":epoch,"distance":-w_sum/(self.STEPS_PER_EPOCH*2) })
|
||||
|
||||
if epoch % self.MAX_EPOCHS == 0:
|
||||
# suffix = "-".join(self.ATTRIBUTES['synthetic']) if isinstance(self.ATTRIBUTES['synthetic'],list) else self.ATTRIBUTES['synthetic']
|
||||
|
@ -452,9 +479,14 @@ class Train (GNet):
|
|||
#
|
||||
#
|
||||
if self.logger :
|
||||
row = {"logs":logs} #,"model":pickle.dump(sess)}
|
||||
|
||||
row = {"logs":logs} #,"model":pickle.dump(sess)}
|
||||
self.logger.write(row=row)
|
||||
#
|
||||
# @TODO:
|
||||
# We should upload the files in the checkpoint
|
||||
# This would allow the learnt model to be portable to another system
|
||||
#
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
class Predict(GNet):
|
||||
"""
|
||||
|
@ -479,38 +511,61 @@ class Predict(GNet):
|
|||
ma = [[i] for i in np.arange(self.NUM_LABELS - 2)]
|
||||
label = y[:, 1] * len(ma) + tf.squeeze(tf.matmul(y[:, 2:], tf.constant(ma, dtype=tf.int32)))
|
||||
|
||||
fake = self.generator.network(inputs=z, label=label)
|
||||
fake = self.generator.network(inputs=z, label=label)
|
||||
init = tf.compat.v1.global_variables_initializer()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
saver = tf.compat.v1.train.Saver()
|
||||
df = pd.DataFrame()
|
||||
CANDIDATE_COUNT = 1000
|
||||
NTH_VALID_CANDIDATE = count = np.random.choice(np.arange(2,60),2)[0]
|
||||
with tf.compat.v1.Session() as sess:
|
||||
|
||||
# sess.run(init)
|
||||
saver.restore(sess, model_dir)
|
||||
labels = np.zeros((self.ROW_COUNT,self.NUM_LABELS) )
|
||||
|
||||
found = []
|
||||
labels= demo
|
||||
f = sess.run(fake,feed_dict={y:labels})
|
||||
#
|
||||
# if we are dealing with numeric values only we can perform a simple marginal sum against the indexes
|
||||
#
|
||||
|
||||
df = ( pd.DataFrame(np.round(f).astype(np.int32)))
|
||||
for i in np.arange(CANDIDATE_COUNT) :
|
||||
|
||||
f = sess.run(fake,feed_dict={y:labels})
|
||||
#
|
||||
# if we are dealing with numeric values only we can perform a simple marginal sum against the indexes
|
||||
# The code below will ensure we have some acceptable cardinal relationships between id and synthetic values
|
||||
#
|
||||
df = ( pd.DataFrame(np.round(f).astype(np.int32)))
|
||||
p = 0 not in df.sum(axis=1).values
|
||||
|
||||
if p:
|
||||
found.append(df)
|
||||
if len(found) == NTH_VALID_CANDIDATE or i == CANDIDATE_COUNT:
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
# i = df.T.index.astype(np.int32) #-- These are numeric pseudonyms
|
||||
# df = (i * df).sum(axis=1)
|
||||
#
|
||||
# In case we are dealing with actual values like diagnosis codes we can perform
|
||||
#
|
||||
df = found[np.random.choice(np.arange(len(found)),1)[0]]
|
||||
columns = self.ATTRIBUTES['synthetic'] if isinstance(self.ATTRIBUTES['synthetic'],list)else [self.ATTRIBUTES['synthetic']]
|
||||
|
||||
r = np.zeros((self.ROW_COUNT,len(columns)))
|
||||
for col in df :
|
||||
i = np.where(df[col])[0]
|
||||
r[i] = col
|
||||
# r = np.zeros((self.ROW_COUNT,len(columns)))
|
||||
r = np.zeros(self.ROW_COUNT)
|
||||
df.columns = self.values
|
||||
if len(found):
|
||||
print (len(found),NTH_VALID_CANDIDATE)
|
||||
# x = df * self.values
|
||||
|
||||
df = pd.DataFrame( df.apply(lambda row: self.values[np.random.choice(np.where(row != 0)[0],1)[0]] ,axis=1))
|
||||
df.columns = columns
|
||||
|
||||
|
||||
|
||||
df = pd.DataFrame(r,columns=columns)
|
||||
|
||||
df[df.columns] = (df.apply(lambda value: self.values[ int(value)],axis=1))
|
||||
return df.to_dict(orient='lists')
|
||||
|
||||
tf.compat.v1.reset_default_graph()
|
||||
|
||||
return df.to_dict(orient='list')
|
||||
# return df.to_dict(orient='list')
|
||||
# count = str(len(os.listdir(self.out_dir)))
|
||||
# _name = os.sep.join([self.out_dir,self.CONTEXT+'-'+count+'.csv'])
|
||||
|
|
|
@ -12,6 +12,7 @@ import pandas as pd
|
|||
import numpy as np
|
||||
import data.gan as gan
|
||||
from transport import factory
|
||||
import threading as thread
|
||||
def train (**args) :
|
||||
"""
|
||||
This function is intended to train the GAN in order to learn about the distribution of the features
|
||||
|
@ -21,30 +22,42 @@ def train (**args) :
|
|||
:data data-frame to be synthesized
|
||||
:context label of what we are synthesizing
|
||||
"""
|
||||
column = args['column']
|
||||
column = args['column'] if (isinstance(args['column'],list)) else [args['column']]
|
||||
|
||||
column_id = args['id']
|
||||
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
|
||||
# logs = args['logs']
|
||||
# real = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
# labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['real'] = pd.get_dummies(df[column]).astype(np.float32).values
|
||||
df.columns = [name.lower() for name in df.columns]
|
||||
|
||||
#
|
||||
# If we have several columns we will proceed one at a time (it could be done in separate threads)
|
||||
# @TODO : Consider performing this task on several threads/GPUs simultaneously
|
||||
#
|
||||
args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
# num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
# max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
context = args['context']
|
||||
|
||||
if 'store' in args :
|
||||
args['store']['args']['doc'] = context
|
||||
logger = factory.instance(**args['store'])
|
||||
args['logger'] = logger
|
||||
|
||||
else:
|
||||
logger = None
|
||||
trainer = gan.Train(**args)
|
||||
# trainer = gan.Train(context=context,max_epochs=max_epochs,num_gpu=num_gpu,real=real,label=labels,column=column,column_id=column_id,logger = logger,logs=logs)
|
||||
return trainer.apply()
|
||||
|
||||
for col in column :
|
||||
args['real'] = pd.get_dummies(df[col]).astype(np.float32).values
|
||||
args['column'] = col
|
||||
args['context'] = col
|
||||
context = args['context']
|
||||
if 'store' in args :
|
||||
args['store']['args']['doc'] = context
|
||||
logger = factory.instance(**args['store'])
|
||||
args['logger'] = logger
|
||||
|
||||
else:
|
||||
logger = None
|
||||
trainer = gan.Train(**args)
|
||||
trainer.apply()
|
||||
def post(**args):
|
||||
"""
|
||||
This uploads the tensorflow checkpoint to a data-store (mongodb, bigquery, s3)
|
||||
NOTE: not implemented yet, the body below is only a placeholder (pass)
|
||||
"""
|
||||
pass
|
||||
def get(**args):
    """
    This function will restore a checkpoint from a persistent storage on to disk

    NOTE: not implemented yet. The keyword arguments are accepted and ignored,
    and the call returns None.

    :param args: keyword arguments (currently unused)
    """
    # The original signature was `def get(**)`, which is a syntax error: a
    # `**` parameter must be named. `args` matches the sibling post(**args).
    pass
|
||||
def generate(**args):
|
||||
"""
|
||||
This function will generate a synthetic dataset on the basis of a model that has been learnt for the dataset
|
||||
|
@ -57,29 +70,27 @@ def generate(**args):
|
|||
"""
|
||||
# df = args['data']
|
||||
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
|
||||
column = args['column']
|
||||
|
||||
column = args['column'] if (isinstance(args['column'],list)) else [args['column']]
|
||||
column_id = args['id']
|
||||
# logs = args['logs']
|
||||
# context = args['context']
|
||||
# num_gpu = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
# max_epochs = 10 if 'max_epochs' not in args else args['max_epochs']
|
||||
|
||||
#
|
||||
#@TODO:
|
||||
# If the identifier is not present, we should find a way to determine or make one
|
||||
#
|
||||
#ocolumns= list(set(df.columns.tolist())- set(columns))
|
||||
|
||||
values = df[column].unique().tolist()
|
||||
values.sort()
|
||||
|
||||
# labels = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values
|
||||
args['values'] = values
|
||||
# handler = gan.Predict (context=context,label=labels,max_epochs=max_epochs,num_gpu=num_gpu,values=values,column=column,logs=logs)
|
||||
handler = gan.Predict (**args)
|
||||
handler.load_meta(column)
|
||||
r = handler.apply()
|
||||
_df = df.copy()
|
||||
_df[column] = r[column]
|
||||
_df = df.copy()
|
||||
for col in column :
|
||||
args['context'] = col
|
||||
args['column'] = col
|
||||
values = df[col].unique().tolist()
|
||||
# values.sort()
|
||||
args['values'] = values
|
||||
#
|
||||
# we can determine the cardinalities here so we know what to allow or disallow
|
||||
handler = gan.Predict (**args)
|
||||
handler.load_meta(col)
|
||||
r = handler.apply()
|
||||
# print (r)
|
||||
_df[col] = r[col]
|
||||
# break
|
||||
return _df
|
|
@ -15,13 +15,15 @@ if 'config' in SYS_ARGS :
|
|||
_df = data.maker.generate(**ARGS)
|
||||
odf = pd.read_csv (ARGS['data'])
|
||||
odf.columns = [name.lower() for name in odf.columns]
|
||||
column = [ARGS['column'] ] #+ ARGS['id']
|
||||
print (column)
|
||||
print (_df[column].risk.evaluate())
|
||||
print (odf[column].risk.evaluate())
|
||||
_x = pd.get_dummies(_df[column]).values
|
||||
y = pd.get_dummies(odf[column]).values
|
||||
N = _df.shape[0]
|
||||
print (np.mean([ wd(_x[i],y[i])for i in range(0,N)]))
|
||||
column = ARGS['column'] if isinstance(ARGS['column'],list) else [ARGS['column']]
|
||||
print(pd.merge(odf,_df, on='id'))
|
||||
# print (_df[column].risk.evaluate(flag='synth'))
|
||||
# print (odf[column].risk.evaluate(flag='original'))
|
||||
# _x = pd.get_dummies(_df[column]).values
|
||||
# y = pd.get_dummies(odf[column]).values
|
||||
# N = _df.shape[0]
|
||||
# print (np.mean([ wd(_x[i],y[i])for i in range(0,N)]))
|
||||
# print (wd(_x[0],y[0]) )
|
||||
|
||||
# column = SYS_ARGS['column']
|
||||
# odf = open(SYS_ARGS['data'])
|
Loading…
Reference in New Issue