bug fix with dimensions @TODO: GPU workload
This commit is contained in:
parent
4024e508a8
commit
ce55848cc8
15
data/gan.py
15
data/gan.py
|
@ -59,20 +59,27 @@ class GNet :
|
|||
self.logs = {}
|
||||
|
||||
self.NUM_GPUS = 1 if 'num_gpu' not in args else args['num_gpu']
|
||||
|
||||
if self.NUM_GPUS > 1 :
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = "4"
|
||||
|
||||
self.X_SPACE_SIZE = args['real'].shape[1] if 'real' in args else 854
|
||||
self.G_STRUCTURE = [128,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE]
|
||||
self.D_STRUCTURE = [self.X_SPACE_SIZE,256,128] #[self.X_SPACE_SIZE, self.X_SPACE_SIZE*2, self.X_SPACE_SIZE] #-- change 854 to number of diagnosis
|
||||
# self.NUM_LABELS = 8 if 'label' not in args elif len(args['label'].shape) args['label'].shape[1]
|
||||
|
||||
if 'label' in args and len(args['label'].shape) == 2 :
|
||||
self.NUM_LABELS = args['label'].shape[1]
|
||||
elif 'label' in args and len(args['label']) == 1 :
|
||||
self.NUM_LABELS = args['label'].shape[0]
|
||||
else:
|
||||
self.NUM_LABELS = 8
|
||||
self.Z_DIM = 128 #self.X_SPACE_SIZE
|
||||
self.BATCHSIZE_PER_GPU = args['real'].shape[0] if 'real' in args else 256
|
||||
# self.Z_DIM = 128 #self.X_SPACE_SIZE
|
||||
self.Z_DIM = 128 #-- used as rows down stream
|
||||
self.G_STRUCTURE = [self.Z_DIM,self.Z_DIM]
|
||||
if 'real' in args :
|
||||
self.D_STRUCTURE = [args['real'].shape[1],256,self.Z_DIM]
|
||||
|
||||
self.BATCHSIZE_PER_GPU = int(args['real'].shape[0]* 1) if 'real' in args else 256
|
||||
self.TOTAL_BATCHSIZE = self.BATCHSIZE_PER_GPU * self.NUM_GPUS
|
||||
self.STEPS_PER_EPOCH = 256 #int(np.load('ICD9/train.npy').shape[0] / 2000)
|
||||
self.MAX_EPOCHS = 10 if 'max_epochs' not in args else int(args['max_epochs'])
|
||||
|
@ -533,6 +540,8 @@ class Predict(GNet):
|
|||
# The code below will insure we have some acceptable cardinal relationships between id and synthetic values
|
||||
#
|
||||
df = ( pd.DataFrame(np.round(f).astype(np.int32)))
|
||||
print (df.head())
|
||||
print ()
|
||||
p = 0 not in df.sum(axis=1).values
|
||||
|
||||
if p:
|
||||
|
|
2
setup.py
2
setup.py
|
@ -4,7 +4,7 @@ import sys
|
|||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
args = {"name":"data-maker","version":"1.0.8","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
args = {"name":"data-maker","version":"1.0.9","author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vanderbilt.edu","license":"MIT",
|
||||
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
||||
args["install_requires"] = ['data-transport@git+https://dev.the-phi.com/git/steve/data-transport.git','tensorflow==1.15','pandas','pandas-gbq','pymongo']
|
||||
args['url'] = 'https://hiplab.mc.vanderbilt.edu/aou/data-maker.git'
|
||||
|
|
Loading…
Reference in New Issue