From 49177957b8e5f7f96621ec2e261f9a59bb2b815a Mon Sep 17 00:00:00 2001 From: Steve Nyemba Date: Fri, 6 Mar 2020 14:56:28 -0600 Subject: [PATCH] ... --- pipeline.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pipeline.py b/pipeline.py index 2ce90a9..9d9c097 100644 --- a/pipeline.py +++ b/pipeline.py @@ -78,6 +78,7 @@ class Components : df = np.array_split(df[columns].values,PART_SIZE) qwriter = factory.instance(type='queue.QueueWriter',args={'queue':'aou.io'}) part_index = 0 + # # let's start n processes to listen & train this mother ... # @@ -145,7 +146,7 @@ class Components : _args['max_epochs'] = 150 if 'max_epochs' not in args else int(args['max_epochs']) # _args['num_gpu'] = int(args['num_gpu']) if 'num_gpu' in args else 1 - if args['num_gpu'] > 1 : + if int(args['num_gpu']) > 1 : _args['gpu'] = int(args['gpu']) if int(args['gpu']) < 8 else np.random.choice(np.arange(8)).astype(int)[0] else: _args['gpu'] = 0 @@ -295,10 +296,17 @@ if __name__ == '__main__' : del args['reader'] columns = DATA.columns.tolist() DATA = np.array_split(DATA[args['columns']],len(content)) + for id in ''.join(content) : + if 'focus' in args and int(args['focus']) != int(id) : + # + # This handles failures/recoveries for whatever reason + # If we are only interested in generating data for a given partition + continue + args['partition'] = id args['data'] = pd.DataFrame(DATA[(int(id))],columns=args['columns']) - if args['num_gpu'] > 1 : + if int(args['num_gpu']) > 1 : args['gpu'] = id else: args['gpu']=0