Compare commits
2 Commits
b2cf5ead53
...
e5af702ddb
Author | SHA1 | Date |
---|---|---|
Steve Nyemba | e5af702ddb | |
Steve Nyemba | f1e2fe3699 | |
16
data/gan.py
16
data/gan.py
|
@ -103,11 +103,12 @@ class GNet :
|
||||||
CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10)
|
CHECKPOINT_SKIPS = int(args['checkpoint_skips']) if 'checkpoint_skips' in args else int(self.MAX_EPOCHS/10)
|
||||||
|
|
||||||
CHECKPOINT_SKIPS = 1 if CHECKPOINT_SKIPS < 1 else CHECKPOINT_SKIPS
|
CHECKPOINT_SKIPS = 1 if CHECKPOINT_SKIPS < 1 else CHECKPOINT_SKIPS
|
||||||
|
|
||||||
# if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
# if self.MAX_EPOCHS < 2*CHECKPOINT_SKIPS :
|
||||||
# CHECKPOINT_SKIPS = 2
|
# CHECKPOINT_SKIPS = 2
|
||||||
# self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
# self.CHECKPOINTS = [1,self.MAX_EPOCHS] + np.repeat( np.divide(self.MAX_EPOCHS,CHECKPOINT_SKIPS),CHECKPOINT_SKIPS ).cumsum().astype(int).tolist()
|
||||||
self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
|
self.CHECKPOINTS = np.repeat(CHECKPOINT_SKIPS, self.MAX_EPOCHS/ CHECKPOINT_SKIPS).cumsum().astype(int).tolist()
|
||||||
|
|
||||||
self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100
|
self.ROW_COUNT = args['real'].shape[0] if 'real' in args else 100
|
||||||
self.CONTEXT = args['context']
|
self.CONTEXT = args['context']
|
||||||
self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None}
|
self.ATTRIBUTES = {"id":args['column_id'] if 'column_id' in args else None,"synthetic":args['column'] if 'column' in args else None}
|
||||||
|
@ -287,8 +288,17 @@ class Generator (GNet):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,**args):
|
def __init__(self,**args):
|
||||||
GNet.__init__(self,**args)
|
if 'trainer' not in args :
|
||||||
self.discriminator = Discriminator(**args)
|
GNet.__init__(self,**args)
|
||||||
|
self.discriminator = Discriminator(**args)
|
||||||
|
else:
|
||||||
|
_args = {}
|
||||||
|
_trainer = args['trainer']
|
||||||
|
for key in vars(_trainer) :
|
||||||
|
value = getattr(_trainer,key)
|
||||||
|
setattr(self,key,value)
|
||||||
|
_args[key] = value
|
||||||
|
self.discriminator = Discriminator(**_args)
|
||||||
def loss(self,**args):
|
def loss(self,**args):
|
||||||
fake = args['fake']
|
fake = args['fake']
|
||||||
label = args['label']
|
label = args['label']
|
||||||
|
|
|
@ -33,6 +33,7 @@ class Learner(Process):
|
||||||
|
|
||||||
|
|
||||||
super(Learner, self).__init__()
|
super(Learner, self).__init__()
|
||||||
|
self._arch = {'init':_args}
|
||||||
self.ndx = 0
|
self.ndx = 0
|
||||||
self._queue = Queue()
|
self._queue = Queue()
|
||||||
self.lock = RLock()
|
self.lock = RLock()
|
||||||
|
@ -44,6 +45,8 @@ class Learner(Process):
|
||||||
self.gpu = None
|
self.gpu = None
|
||||||
|
|
||||||
self.info = _args['info']
|
self.info = _args['info']
|
||||||
|
if 'context' not in self.info :
|
||||||
|
self.info['context'] = self.info['from']
|
||||||
self.columns = self.info['columns'] if 'columns' in self.info else None
|
self.columns = self.info['columns'] if 'columns' in self.info else None
|
||||||
self.store = _args['store']
|
self.store = _args['store']
|
||||||
|
|
||||||
|
@ -97,9 +100,12 @@ class Learner(Process):
|
||||||
# __info = (pd.DataFrame(self._states)[['name','path','args']]).to_dict(orient='records')
|
# __info = (pd.DataFrame(self._states)[['name','path','args']]).to_dict(orient='records')
|
||||||
if self._states :
|
if self._states :
|
||||||
__info = {}
|
__info = {}
|
||||||
|
# print (self._states)
|
||||||
for key in self._states :
|
for key in self._states :
|
||||||
__info[key] = [{"name":_item['name'],"args":_item['args'],"path":_item['path']} for _item in self._states[key]]
|
_pipeline = self._states[key]
|
||||||
|
|
||||||
|
# __info[key] = ([{'name':_payload['name']} for _payload in _pipeline])
|
||||||
|
__info[key] = [{"name":_item['name'],"args":_item['args'],"path":_item['path']} for _item in self._states[key] if _item ]
|
||||||
self.log(object='state-space',action='load',input=__info)
|
self.log(object='state-space',action='load',input=__info)
|
||||||
|
|
||||||
|
|
||||||
|
@ -270,18 +276,23 @@ class Trainer(Learner):
|
||||||
#
|
#
|
||||||
_epochs = [_e for _e in gTrain.logs['epochs'] if _e['path'] != '']
|
_epochs = [_e for _e in gTrain.logs['epochs'] if _e['path'] != '']
|
||||||
_epochs.sort(key=lambda _item: _item['loss'],reverse=False)
|
_epochs.sort(key=lambda _item: _item['loss'],reverse=False)
|
||||||
|
|
||||||
_args['network_args']['max_epochs'] = _epochs[0]['epochs']
|
_args['network_args']['max_epochs'] = _epochs[0]['epochs']
|
||||||
self.log(action='autopilot',input={'epoch':_epochs[0]})
|
self.log(action='autopilot',input={'epoch':_epochs[0]})
|
||||||
g = Generator(**_args)
|
|
||||||
# g.run()
|
# g.run()
|
||||||
|
|
||||||
end = datetime.now() #.strftime('%Y-%m-%d %H:%M:%S')
|
end = datetime.now() #.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
_min = float((end-beg).seconds/ 60)
|
_min = float((end-beg).seconds/ 60)
|
||||||
_logs = {'action':'train','input':{'start':beg.strftime('%Y-%m-%d %H:%M:%S'),'minutes':_min,"unique_counts":self._encoder._io[0]}}
|
_logs = {'action':'train','input':{'start':beg.strftime('%Y-%m-%d %H:%M:%S'),'minutes':_min,"unique_counts":self._encoder._io[0]}}
|
||||||
self.log(**_logs)
|
self.log(**_logs)
|
||||||
self._g = g
|
|
||||||
if self.autopilot :
|
if self.autopilot :
|
||||||
|
|
||||||
|
# g = Generator(**_args)
|
||||||
|
|
||||||
|
g = Generator(**self._arch['init'])
|
||||||
|
self._g = g
|
||||||
self._g.run()
|
self._g.run()
|
||||||
#
|
#
|
||||||
#@TODO Find a way to have the data in the object ....
|
#@TODO Find a way to have the data in the object ....
|
||||||
|
@ -300,10 +311,15 @@ class Generator (Learner):
|
||||||
#
|
#
|
||||||
# We need to load the mapping information for the space we are working with ...
|
# We need to load the mapping information for the space we are working with ...
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
self.network_args['candidates'] = int(_args['candidates']) if 'candidates' in _args else 1
|
self.network_args['candidates'] = int(_args['candidates']) if 'candidates' in _args else 1
|
||||||
filename = os.sep.join([self.network_args['logs'],'output',self.network_args['context'],'map.json'])
|
# filename = os.sep.join([self.network_args['logs'],'output',self.network_args['context'],'map.json'])
|
||||||
|
_suffix = self.network_args['context']
|
||||||
|
filename = os.sep.join([self.network_args['logs'],'output',self.network_args['context'],'meta-',_suffix,'.json'])
|
||||||
self.log(**{'action':'init-map','input':{'filename':filename,'exists':os.path.exists(filename)}})
|
self.log(**{'action':'init-map','input':{'filename':filename,'exists':os.path.exists(filename)}})
|
||||||
if os.path.exists(filename):
|
if os.path.exists(filename):
|
||||||
|
|
||||||
file = open(filename)
|
file = open(filename)
|
||||||
self._map = json.loads(file.read())
|
self._map = json.loads(file.read())
|
||||||
file.close()
|
file.close()
|
||||||
|
@ -580,6 +596,7 @@ class factory :
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
if _args['apply'] in [apply.RANDOM] :
|
if _args['apply'] in [apply.RANDOM] :
|
||||||
pthread = Shuffle(**_args)
|
pthread = Shuffle(**_args)
|
||||||
|
|
|
@ -69,7 +69,7 @@ class Date(Post):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
class Approximate(Post):
|
class Approximate(Post):
|
||||||
def apply(**_args):
|
def apply(**_args):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -31,12 +31,22 @@ class State :
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pointer = _item['module']
|
pointer = _item['module']
|
||||||
_args = _item['args']
|
|
||||||
|
if type(pointer).__name__ != 'function':
|
||||||
|
_args = _item['args'] if 'args' in _item else {}
|
||||||
|
else:
|
||||||
|
pointer = _item['module']
|
||||||
|
|
||||||
|
_args = _item['args'] if 'args' in _item else {}
|
||||||
|
|
||||||
|
|
||||||
_data = pointer(_data,_args)
|
_data = pointer(_data,_args)
|
||||||
return _data
|
return _data
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def instance(_args):
|
def instance(_args):
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
pre = []
|
pre = []
|
||||||
post=[]
|
post=[]
|
||||||
|
|
||||||
|
@ -45,8 +55,20 @@ class State :
|
||||||
#
|
#
|
||||||
# If the item has a path property it should be ignored
|
# If the item has a path property it should be ignored
|
||||||
path = _args[key]['path'] if 'path' in _args[key] else ''
|
path = _args[key]['path'] if 'path' in _args[key] else ''
|
||||||
out[key] = [ State._build(dict(_item,**{'path':path})) if 'path' not in _item else State._build(_item) for _item in _args[key]['pipeline']]
|
# out[key] = [ State._build(dict(_item,**{'path':path})) if 'path' not in _item else State._build(_item) for _item in _args[key]['pipeline']]
|
||||||
|
out[key] = []
|
||||||
|
for _item in _args[key]['pipeline'] :
|
||||||
|
|
||||||
|
if type(_item).__name__ == 'function':
|
||||||
|
_stageInfo = {'module':_item,'name':_item.__name__,'args':{},'path':''}
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if 'path' in _item :
|
||||||
|
_stageInfo = State._build(dict(_item,**{'path':path}))
|
||||||
|
else :
|
||||||
|
_stageInfo= State._build(_item)
|
||||||
|
out[key].append(_stageInfo)
|
||||||
|
# print ([out])
|
||||||
return out
|
return out
|
||||||
# if 'pre' in _args:
|
# if 'pre' in _args:
|
||||||
# path = _args['pre']['path'] if 'path' in _args['pre'] else ''
|
# path = _args['pre']['path'] if 'path' in _args['pre'] else ''
|
||||||
|
@ -68,11 +90,18 @@ class State :
|
||||||
pass
|
pass
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build(_args):
|
def _build(_args):
|
||||||
|
"""
|
||||||
|
This function builds the object {module,path} where module is extracted from a file (if needed)
|
||||||
|
:param _args: dictionary containing attributes as key/value pairs
|
||||||
|
It can also be a function
|
||||||
|
"""
|
||||||
|
#
|
||||||
|
# In the event an actual pointer is passed, we should do the following
|
||||||
|
|
||||||
_info = State._extract(_args)
|
_info = State._extract(_args)
|
||||||
# _info = dict(_args,**_info)
|
# _info = dict(_args,**_info)
|
||||||
|
|
||||||
_info['module'] = State._instance(_info)
|
_info['module'] = State._instance(_info)
|
||||||
return _info if _info['module'] is not None else None
|
return _info if _info['module'] is not None else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -4,7 +4,7 @@ import sys
|
||||||
|
|
||||||
def read(fname):
|
def read(fname):
|
||||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||||
args = {"name":"data-maker","version":"1.6.4",
|
args = {"name":"data-maker","version":"1.6.6",
|
||||||
"author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT",
|
"author":"Vanderbilt University Medical Center","author_email":"steve.l.nyemba@vumc.org","license":"MIT",
|
||||||
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
"packages":find_packages(),"keywords":["healthcare","data","transport","protocol"]}
|
||||||
args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow']
|
args["install_requires"] = ['data-transport@git+https://github.com/lnyemba/data-transport.git','tensorflow']
|
||||||
|
|
Loading…
Reference in New Issue