From 553ee75a0681a80ea95fd0bdcf7920d383510c8d Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Tue, 25 Feb 2020 11:41:40 -0600 Subject: [PATCH] bug fix around shape of candidate data to generate --- data/gan.py | 49 +++++++++++++++++++++++++++++++----------- data/maker/__init__.py | 8 +++---- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/data/gan.py b/data/gan.py index 4c05566..6e6454e 100644 --- a/data/gan.py +++ b/data/gan.py @@ -166,7 +166,15 @@ class GNet : return _object def mkdir (self,path): if not os.path.exists(path) : - os.mkdir(path) + if os.sep in path : + pass + root = [] + for loc in path.split(os.sep) : + root.append(loc) + os.mkdir(os.sep.join(root)) + + else: + os.mkdir(path) def normalize(self,**args): @@ -520,8 +528,10 @@ class Predict(GNet): """ def __init__(self,**args): GNet.__init__(self,**args) - self.generator = Generator(**args) - self.values = args['values'] + self.generator = Generator(**args) + self.values = args['values'] + self.ROW_COUNT = args['row_count'] + self.MISSING_VALUES = args['no_value'] def load_meta(self, column): super().load_meta(column) self.generator.load_meta(column) @@ -532,8 +542,8 @@ class Predict(GNet): model_dir = os.sep.join([self.train_dir,suffix+'-'+str(self.MAX_EPOCHS)]) demo = self._LABEL #np.zeros([self.ROW_COUNT,self.NUM_LABELS]) #args['de"shape":{"LABEL":list(self._LABEL.shape)} mo'] tf.compat.v1.reset_default_graph() - z = tf.random.normal(shape=[self.BATCHSIZE_PER_GPU, self.Z_DIM]) - y = tf.compat.v1.placeholder(shape=[self.BATCHSIZE_PER_GPU, self.NUM_LABELS], dtype=tf.int32) + z = tf.random.normal(shape=[self.ROW_COUNT, self.Z_DIM]) + y = tf.compat.v1.placeholder(shape=[self.ROW_COUNT, self.NUM_LABELS], dtype=tf.int32) if self._LABEL is not None : ma = [[i] for i in np.arange(self.NUM_LABELS - 2)] label = y[:, 1] * len(ma) + tf.squeeze(tf.matmul(y[:, 2:], tf.constant(ma, dtype=tf.int32))) @@ -556,7 +566,7 @@ class Predict(GNet): labels = None found = [] - + ratio = [] for i in np.arange(CANDIDATE_COUNT) : if labels : f = sess.run(fake,feed_dict={y:labels}) @@ -569,10 +579,11 @@ class Predict(GNet): df = ( pd.DataFrame(np.round(f).astype(np.int32))) p = 0 not in df.sum(axis=1).values x = df.sum(axis=1).values - print ( [np.sum(x),x.size]) - if np.divide( np.sum(x), x.size) : + + if np.divide( np.sum(x), x.size) > .9 or p: + ratio.append(np.divide( np.sum(x), x.size)) found.append(df) - if len(found) == NTH_VALID_CANDIDATE or i == CANDIDATE_COUNT: + if i == CANDIDATE_COUNT: break else: continue @@ -582,8 +593,9 @@ class Predict(GNet): # # In case we are dealing with actual values like diagnosis codes we can perform # - - df = found[np.random.choice(np.arange(len(found)),1)[0]] + INDEX = np.random.choice(np.arange(len(found)),1)[0] + INDEX = ratio.index(np.max(ratio)) + df = found[INDEX] columns = self.ATTRIBUTES['synthetic'] if isinstance(self.ATTRIBUTES['synthetic'],list)else [self.ATTRIBUTES['synthetic']] # r = np.zeros((self.ROW_COUNT,len(columns))) @@ -592,9 +604,20 @@ class Predict(GNet): if len(found): print (len(found),NTH_VALID_CANDIDATE) # x = df * self.values - - df = pd.DataFrame( df.apply(lambda row: self.values[np.random.choice(np.where(row != 0)[0],1)[0]] ,axis=1)) + # + # let's get the missing rows (if any) ... + # + ii = df.apply(lambda row: np.sum(row) == 0 ,axis=1) + if ii : + # + #@TODO Have this be a configurable variable + missing = np.repeat(0, np.where(ii==1)[0].size) + else: + missing = [] + i = np.where(ii == 0)[0] + df = pd.DataFrame( df.iloc.apply(lambda row: self.values[np.random.choice(np.where(row != 0)[0],1)[0]] ,axis=1)) df.columns = columns + df = df[columns[0]].append(pd.Series(missing)) diff --git a/data/maker/__init__.py b/data/maker/__init__.py index 3c04b57..6205b78 100644 --- a/data/maker/__init__.py +++ b/data/maker/__init__.py @@ -77,25 +77,23 @@ def generate(**args): df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data']) column = args['column'] if (isinstance(args['column'],list)) else [args['column']] - column_id = args['id'] + # column_id = args['id'] # #@TODO: # If the identifier is not present, we should fine a way to determine or make one # - # args['label'] = pd.get_dummies(df[column_id]).astype(np.float32).values - bwrangler = Binary() - # args['label'] = bwrangler.Export(df[[column_id]]) _df = df.copy() for col in column : args['context'] = col args['column'] = col values = df[col].unique().tolist() - # values.sort() args['values'] = values + args['row_count'] = df.shape[0] # # we can determine the cardinalities here so we know what to allow or disallow handler = gan.Predict (**args) handler.load_meta(col) + # handler.ROW_COUNT = df[col].shape[0] r = handler.apply() # print (r) _df[col] = r[col]