From af6ab356d832014d9608ff70812d47e07b24aa53 Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Mon, 16 Mar 2020 16:22:34 -0500 Subject: [PATCH] bug fix: index number or context --- data/gan.py | 2 +- pipeline.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/data/gan.py b/data/gan.py index 28d5ea3..c85776a 100644 --- a/data/gan.py +++ b/data/gan.py @@ -424,7 +424,7 @@ class Train (GNet): dataset = tf.data.Dataset.from_tensor_slices(features_placeholder) # labels_placeholder = None dataset = dataset.repeat(10000) - print ([' ******* ',self.BATCHSIZE_PER_GPU]) + dataset = dataset.batch(batch_size=self.BATCHSIZE_PER_GPU) dataset = dataset.prefetch(1) # iterator = dataset.make_initializable_iterator() diff --git a/pipeline.py b/pipeline.py index bfdd72e..b838043 100644 --- a/pipeline.py +++ b/pipeline.py @@ -244,8 +244,19 @@ if __name__ == '__main__' : f = open (filename) PIPELINE = json.loads(f.read()) f.close() - index = int(SYS_ARGS['index']) if 'index' in SYS_ARGS else 0 - + index = SYS_ARGS['index'] + if index.isnumeric() : + index = int(SYS_ARGS['index']) + else: + # + # The index provided is a key to a pipeline entry mainly the context + # + N = len(PIPELINE) + f = [i for i in range(0,N) if PIPELINE[i]['context'] == index] + index = f[0] if f else 0 + # + # print + print ("..::: ",PIPELINE[index]['context']) args = (PIPELINE[index]) args = dict(args,**SYS_ARGS)