"""
(c) 2019 Data Maker, hiplab.mc.vanderbilt.edu
version 1.0.0

This package serves as a proxy to the overall usage of the framework.
It is designed to generate synthetic data from an original dataset using
deep learning techniques.

@TODO:
    - Make configurable GPU, EPOCHS
"""
import pandas as pd
import numpy as np
import data.gan as gan
from transport import factory


def _one_hot(series):
    """Return the one-hot (dummy) encoding of *series* as a float32 numpy array."""
    return pd.get_dummies(series).astype(np.float32).values


def train(**args):
    """
    Train the GAN so that it learns the distribution of the features.

    :column      column that needs to be synthesized (discrete)
    :id          column identifying an entity; it is one-hot encoded and
                 passed to the trainer as labels
    :data        data-frame to be synthesized
    :logs        location on disk for the training output (forwarded to
                 gan.Train as ``logs``)
    :context     label of what we are synthesizing
    :max_epochs  (optional) number of training epochs, defaults to 10
    :store       (optional) transport configuration used to build a logger via
                 ``factory.instance``; a copy of its ``'args'`` entry is given
                 ``'doc' = context``. The caller's dictionary is not modified.

    Returns whatever ``gan.Train(...).apply()`` returns.
    """
    column = args['column']
    column_id = args['id']
    df = args['data']
    logs = args['logs']
    context = args['context']
    max_epochs = args.get('max_epochs', 10)

    real = _one_hot(df[column])
    labels = _one_hot(df[column_id])

    logger = None
    if 'store' in args:
        # Copy the store configuration (including its nested 'args' dict)
        # before tagging it with the context, so that the caller's config is
        # not mutated as a side effect of training.
        store = dict(args['store'])
        store['args'] = dict(store['args'])
        store['args']['doc'] = context
        logger = factory.instance(**store)

    trainer = gan.Train(
        context=context,
        max_epochs=max_epochs,
        real=real,
        label=labels,
        column=column,
        column_id=column_id,
        logger=logger,
        logs=logs,
    )
    return trainer.apply()


def generate(**args):
    """
    Generate a synthetic dataset from a model previously learnt for the dataset.

    :data     data-frame to be synthesized
    :column   column that needs to be synthesized (discrete)
    :id       column identifying an entity; it is one-hot encoded and passed
              to the predictor as labels
    :context  label of what we are synthesizing
    :logs     location on disk of the learnt model. NOTE(review): this value
              is not forwarded to gan.Predict here, so it is currently unused.

    @return pandas.DataFrame: a copy of *data* whose *column* has been
            replaced by the values produced by the model.
    """
    df = args['data']
    column = args['column']
    column_id = args['id']
    context = args['context']
    #
    # @TODO:
    # If the identifier is not present, we should find a way to determine or make one
    #
    # The distinct values of the column, in sorted order.
    values = sorted(df[column].unique().tolist())
    labels = _one_hot(df[column_id])

    handler = gan.Predict(context=context, label=labels, values=values, column=column)
    handler.load_meta(column)
    r = handler.apply()

    _df = df.copy()
    # NOTE(review): if r[column] is a pandas Series, this assignment aligns on
    # index; confirm its index matches df's, otherwise rows may become NaN.
    _df[column] = r[column]
    return _df