From e02a4a60abd8a936d77b5720beb0e27a34718307 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 6 Mar 2020 15:26:18 -0600 Subject: [PATCH] acceptance criteria fix --- data/gan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/data/gan.py b/data/gan.py index a591f34..80c3f8e 100644 --- a/data/gan.py +++ b/data/gan.py @@ -584,7 +584,7 @@ class Predict(GNet): p = 0 not in df.sum(axis=1).values x = df.sum(axis=1).values - if np.divide( np.sum(x), x.size) > .9 or p and np.sum(x) == x.size: + if x.max() == 1 and np.divide( np.sum(x), x.size) > .9 or p and np.sum(x) == x.size and x.size == self.values.size: ratio.append(np.divide( np.sum(x), x.size)) found.append(df) if i == CANDIDATE_COUNT: @@ -606,7 +606,9 @@ class Predict(GNet): # r = np.zeros((self.ROW_COUNT,len(columns))) # r = np.zeros(self.ROW_COUNT) - + if self.logger : + info = {"found":len(found),"selected":INDEX, "ratio": ratio[INDEX],"rows":df.shape[0],"cols":df.shape[1]} + self.logger.write({"module":"gan-generate","action":"generate","input":info}) df.columns = self.values if len(found): # print (len(found),NTH_VALID_CANDIDATE)