gpu indexing
This commit is contained in:
parent
5a16e325ac
commit
a73e186f77
|
@ -64,7 +64,9 @@ class GNet :
|
|||
self.GPU_CHIPS = [0]
|
||||
if 'CUDA_VISIBLE_DEVICES' in os.environ :
|
||||
os.environ.pop('CUDA_VISIBLE_DEVICES')
|
||||
self.NUM_GPUS = len(self.GPU_CHIPS)
|
||||
self.NUM_GPUS = 0
|
||||
else:
|
||||
self.NUM_GPUS = len(self.GPU_CHIPS)
|
||||
|
||||
self.PARTITION = args['partition']
|
||||
# if self.NUM_GPUS > 1 :
|
||||
|
|
|
@ -86,18 +86,6 @@ def train (**_args):
|
|||
:params sql
|
||||
:params store
|
||||
"""
|
||||
#
|
||||
# Let us prepare the data by calling the utility function
|
||||
#
|
||||
# if 'file' in _args :
|
||||
# #
|
||||
# # We are reading data from a file
|
||||
# _args['data'] = pd.read_csv(_args['file'])
|
||||
# else:
|
||||
# #
|
||||
# # data will be read from elsewhere (a data-store)...
|
||||
# pass
|
||||
# if 'ignore' in _args and 'columns' in _args['ignore']:
|
||||
|
||||
_inputhandler = prepare.Input(**_args)
|
||||
values,_matrix = _inputhandler.convert()
|
||||
|
@ -125,6 +113,8 @@ def train (**_args):
|
|||
args['matrix_size'] = _matrix.shape[0]
|
||||
args['batch_size'] = 2000
|
||||
args['partition'] = 0 if 'partition' not in _args else _args['partition']
|
||||
if 'gpu' in _args :
|
||||
args['gpu'] = _args['gpu']
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = str(args['gpu']) if 'gpu' in args else '0'
|
||||
|
||||
trainer = gan.Train(**args)
|
||||
|
@ -137,50 +127,7 @@ def train (**_args):
|
|||
|
||||
trainer.apply()
|
||||
pass
|
||||
def _train (**args) :
|
||||
"""
|
||||
This function is intended to train the GAN in order to learn about the distribution of the features
|
||||
:column columns that need to be synthesized (discrete)
|
||||
:logs where the output of the (location on disk)
|
||||
:id identifier of the dataset
|
||||
:data data-frame to be synthesized
|
||||
:context label of what we are synthesizing
|
||||
"""
|
||||
column = args['column'] if (isinstance(args['column'],list)) else [args['column']]
|
||||
# CONTINUOUS = args['continuous'] if 'continuous' in args else []
|
||||
# column_id = args['id']
|
||||
df = args['data'] if not isinstance(args['data'],str) else pd.read_csv(args['data'])
|
||||
df.columns = [name.lower() for name in df.columns]
|
||||
#
|
||||
# @TODO:
|
||||
# Consider sequential training of sub population for extremely large datasets
|
||||
#
|
||||
|
||||
#
|
||||
# If we have several columns we will proceed one at a time (it could be done in separate threads)
|
||||
# @TODO : Consider performing this task on several threads/GPUs simulataneously
|
||||
#
|
||||
for col in column :
|
||||
msize = args['matrix_size'] if 'matrix_size' in args else -1
|
||||
args['real'] = (Binary()).apply(df[col],msize)
|
||||
|
||||
context = args['context']
|
||||
if 'store' in args :
|
||||
args['store']['args']['doc'] = context
|
||||
logger = factory.instance(**args['store'])
|
||||
args['logger'] = logger
|
||||
info = {"rows":args['real'].shape[0],"cols":args['real'].shape[1],"name":col,"partition":args['partition']}
|
||||
logger.write({"module":"gan-train","action":"data-prep","input":info})
|
||||
|
||||
else:
|
||||
logger = None
|
||||
args['column'] = col
|
||||
args['context'] = col
|
||||
|
||||
#
|
||||
# If the s
|
||||
trainer = gan.Train(**args)
|
||||
trainer.apply()
|
||||
def get(**args):
|
||||
"""
|
||||
This function will restore a checkpoint from a persistant storage on to disk
|
||||
|
@ -214,6 +161,8 @@ def generate(**_args):
|
|||
_inputhandler = prepare.Input(**_args)
|
||||
values,_matrix = _inputhandler.convert()
|
||||
args['values'] = np.array(values)
|
||||
if 'gpu' in _args :
|
||||
args['gpu'] = _args['gpu']
|
||||
|
||||
handler = gan.Predict (**args)
|
||||
handler.load_meta(None)
|
||||
|
|
|
@ -87,6 +87,8 @@ class Components :
|
|||
_index = str(gpu[0])
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = _index
|
||||
return gpu
|
||||
else :
|
||||
return None
|
||||
def train(self,**args):
|
||||
"""
|
||||
This function will perform training on the basis of a given pointer that reads data
|
||||
|
|
Loading…
Reference in New Issue