From 140a4c4573a0d63ecadd592677fba4ff7ae15f3d Mon Sep 17 00:00:00 2001 From: "Steve L. Nyemba -- The Architect" Date: Thu, 27 Sep 2018 10:33:52 -0500 Subject: [PATCH] bug fix: prosecutor risk, marketer risk --- notebooks/risk.ipynb | 205 +++++++++++++++++++++++++-------------- src/params.py | 17 ++++ src/risk.py | 226 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 374 insertions(+), 74 deletions(-) create mode 100644 src/params.py create mode 100644 src/risk.py diff --git a/notebooks/risk.ipynb b/notebooks/risk.ipynb index fc86de5..1109529 100644 --- a/notebooks/risk.ipynb +++ b/notebooks/risk.ipynb @@ -2,15 +2,29 @@ "cells": [ { "cell_type": "code", - "execution_count": 66, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dev-deid-600@aou-res-deid-vumc-test.iam.gserviceaccount.com df0ac049-d5b6-416f-ab3c-6321eda919d6 2018-09-25 08:18:34.829000+00:00 DONE\n" + ] + } + ], "source": [ "import pandas as pd\n", "import numpy as np\n", "from google.cloud import bigquery as bq\n", "\n", - "client = bq.Client.from_service_account_json('/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')" + "client = bq.Client.from_service_account_json('/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')\n", + "# pd.read_gbq(query=\"select * from raw.observation limit 10\",private_key='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json')\n", + "jobs = client.list_jobs()\n", + "for job in jobs :\n", + "# print dir(job)\n", + " print job.user_email,job.job_id,job.started, job.state\n", + " break" ] }, { @@ -25,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 181, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -68,7 +82,7 @@ " else:\n", " x_ = args['xi']\n", " for xi in x_ :\n", - " fields += (['.'.join([xi['name'],name]) for name in xi['fields'] if name != args['join']])\n", + " fields += (['.'.join([xi['name'], name]) for name in xi['fields'] if name != args['join']])\n", " return fields\n", "def generate_sql(**args):\n", " \"\"\"\n", @@ -97,7 +111,27 @@ " tmp.append(ON_SQL)\n", " INNER_JOINS += [JOIN_SQL + \" AND \".join(tmp)]\n", " return SQL + \" \".join(INNER_JOINS)\n", - " \n", + "def get_final_sql(**args):\n", + " xo = args['xo']\n", + " xi = args['xi']\n", + " join=args['join']\n", + " prefix = args['prefix'] if 'prefix' in args else ''\n", + " fields = get_fields (xo=xo,xi=xi,join=join)\n", + " k = len(fields)\n", + " n = np.random.randint(2,k) #-- number of fields to select\n", + " i = np.random.randint(0,k,size=n)\n", + " fields = [name for name in fields if fields.index(name) in i]\n", + " base_sql = generate_sql(xo=xo,xi=xi,prefix)\n", + " SQL = \"\"\"\n", + " SELECT AVERAGE(count),size,n as selected_features,k as total_features\n", + " FROM(\n", + " SELECT COUNT(*) as count,count(:join) as pop,sum(:n) as N,sum(:k) as k,:fields\n", + " FROM (:sql)\n", + " GROUP BY :fields\n", + " ) \n", + " order by 1\n", + " \n", + " \"\"\".replace(\":sql\",base_sql)\n", "# sql = \"SELECT :fields FROM :xo.name INNER JOIN :xi.name ON :xi.name.:xi.y = :xo.y \"\n", "# fields = \",\".join(get_fields(xo=xi,xi=xi,join=xi['y']))\n", " \n", @@ -111,24 +145,39 @@ }, { "cell_type": "code", - "execution_count": 183, + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "xo = {\"name\":\"person\",\"fields\":['person_id','date_of_birth','race','value_as_number']}\n", + "xi = 
[{\"name\":\"measurement\",\"fields\":['person_id','value_as_number','value_source_value']}] #,{\"name\":\"observation\",\"fields\":[\"person_id\",\"value_as_string\",\"observation_source_value\"]}]\n", + "# generate_sql(xo=xo,xi=xi,join=\"person_id\",prefix='raw')\n", + "fields = get_fields(xo=xo,xi=xi,join='person_id')\n", + "ofields = list(fields)\n", + "k = len(fields)\n", + "n = np.random.randint(2,k) #-- number of fields to select\n", + "i = np.random.randint(0,k,size=n)\n", + "fields = [name for name in fields if fields.index(name) in i]" + ] + }, + { + "cell_type": "code", + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'SELECT :fields FROM raw.person INNER JOIN raw.measurement ON measurement.person_id = person.person_id'" + "['person.race', 'person.value_as_number', 'measurement.value_source_value']" ] }, - "execution_count": 183, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "xo = {\"name\":\"person\",\"fields\":['person_id','date_of_birth','race']}\n", - "xi = [{\"name\":\"measurement\",\"fields\":['person_id','value_as_number','value_source_value']}] #,{\"name\":\"observation\",\"fields\":[\"person_id\",\"value_as_string\",\"observation_source_value\"]}]\n", - "generate_sql(xo=xo,xi=xi,join=\"person_id\",prefix='raw')" + "fields\n" ] }, { @@ -179,69 +228,16 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[u'condition_occurrence.condition_occurrence_id',\n", - " u'condition_occurrence.person_id',\n", - " u'condition_occurrence.condition_concept_id',\n", - " u'condition_occurrence.condition_start_date',\n", - " u'condition_occurrence.condition_start_datetime',\n", - " u'condition_occurrence.condition_end_date',\n", - " u'condition_occurrence.condition_end_datetime',\n", - " u'condition_occurrence.condition_type_concept_id',\n", - " u'condition_occurrence.stop_reason',\n", - " u'condition_occurrence.provider_id',\n", - " u'condition_occurrence.visit_occurrence_id',\n", - " u'condition_occurrence.condition_source_value',\n", - " u'condition_occurrence.condition_source_concept_id',\n", - " u'death.death_date',\n", - " u'death.death_datetime',\n", - " u'death.death_type_concept_id',\n", - " u'death.cause_concept_id',\n", - " u'death.cause_source_value',\n", - " u'death.cause_source_concept_id',\n", - " u'device_exposure.device_exposure_id',\n", - " u'device_exposure.device_concept_id',\n", - " u'device_exposure.device_exposure_start_date',\n", - " u'device_exposure.device_exposure_start_datetime',\n", - " u'device_exposure.device_exposure_end_date',\n", - " u'device_exposure.device_exposure_end_datetime',\n", - " u'device_exposure.device_type_concept_id',\n", - " u'device_exposure.unique_device_id',\n", - " u'device_exposure.quantity',\n", - " u'device_exposure.provider_id',\n", - " u'device_exposure.visit_occurrence_id',\n", - " u'device_exposure.device_source_value',\n", - " u'device_exposure.device_source_concept_id',\n", - " u'drug_exposure.drug_exposure_id',\n", - " u'drug_exposure.drug_concept_id',\n", - " u'drug_exposure.drug_exposure_start_date',\n", - " u'drug_exposure.drug_exposure_start_datetime',\n", - " u'drug_exposure.drug_exposure_end_date',\n", - " u'drug_exposure.drug_exposure_end_datetime',\n", - " u'drug_exposure.drug_type_concept_id',\n", - " u'drug_exposure.stop_reason',\n", - " u'drug_exposure.refills',\n", - " u'drug_exposure.quantity',\n", - " u'drug_exposure.days_supply',\n", - " 
u'drug_exposure.sig',\n", - " u'drug_exposure.route_concept_id',\n", - " u'drug_exposure.effective_drug_dose',\n", - " u'drug_exposure.dose_unit_concept_id',\n", - " u'drug_exposure.lot_number',\n", - " u'drug_exposure.provider_id',\n", - " u'drug_exposure.visit_occurrence_id',\n", - " u'drug_exposure.drug_source_value',\n", - " u'drug_exposure.drug_source_concept_id',\n", - " u'drug_exposure.route_source_value',\n", - " u'drug_exposure.dose_unit_source_value']" + "array([1, 3, 0, 0])" ] }, - "execution_count": 111, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -250,12 +246,7 @@ "#\n", "# find every table with person id at the very least or a subset of fields\n", "#\n", - "info = get_tables(client,'raw',['person_id'])\n", - "# get_fields(xo=names[0],xi=names[1:4],join='person_id')\n", - "\n", - "# q = ['person_id']\n", - "# pairs = list(itertools.combinations(names,len(names)))\n", - "# pairs[0]" + "np.random.randint(0,4,size=4)" ] }, { @@ -287,6 +278,72 @@ "x_ = 1" ] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "x_ = pd.DataFrame({\"group\":[1,1,1,1,1], \"size\":[2,1,1,1,1]})" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
size
group
11.2
\n", + "
" + ], + "text/plain": [ + " size\n", + "group \n", + "1 1.2" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_.groupby(['group']).mean()\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/params.py b/src/params.py new file mode 100644 index 0000000..428ff00 --- /dev/null +++ b/src/params.py @@ -0,0 +1,17 @@ +import sys +SYS_ARGS={} +if len(sys.argv) > 1 : + N = len(sys.argv) + for i in range(1,N) : + value = 1 + + if sys.argv[i].startswith('--') : + key = sys.argv[i].replace('-','') + + if i + 1 < N and not sys.argv[i+1].startswith('--') : + value = sys.argv[i + 1].strip() + SYS_ARGS[key] = value + i += 2 + elif 'action' not in SYS_ARGS: + SYS_ARGS['action'] = sys.argv[i].strip() + diff --git a/src/risk.py b/src/risk.py new file mode 100644 index 0000000..b942b14 --- /dev/null +++ b/src/risk.py @@ -0,0 +1,226 @@ +""" + Steve L. Nyemba & Brad Malin + Health Information Privacy Lab. + + This code is proof of concept as to how risk is computed against a database (at least a schema). + The engine will read tables that have a given criteria (patient id) and generate a dataset by performing joins. + Because joins are process intensive we decided to add a limit to the records pulled. + + TL;DR: + This engine generates a dataset and computes risk (marketer and prosecutor) + Assumptions: + - We assume tables that reference patients will name the keys identically (best practice). This allows us to be able to leverage data store's that don't support referential integrity + + Usage : + + Limitations + - It works against bigquery for now + @TODO: + - Need to write a transport layer (database interface) + - Support for referential integrity, so one table can be selected and a dataset derived given referential integrity + - Add support for journalist risk +""" +import pandas as pd +import numpy as np +from google.cloud import bigquery as bq +import time +from params import SYS_ARGS +class utils : + """ + This class is a utility class that will generate SQL-11 compatible code in order to run the risk assessment + + @TODO: plugins for other data-stores + """ + def __init__(self,**args): + # self.path = args['path'] + self.client = args['client'] + + def get_tables(self,**args): #id,key='person_id'): + """ + This function returns a list of tables given a key. The key is the name of the field that uniquely designates a patient/person + in the database. The list of tables are tables that can be joined given the provided field. + + @param key name of the patient field + @param dataset dataset name + @param client initialized bigquery client () + @return [{name,fields:[],row_count}] + """ + dataset = args['dataset'] + client = args['client'] + key = args['key'] + r = [] + ref = client.dataset(dataset) + tables = list(client.list_tables(ref)) + for table in tables : + + if table.table_id.strip() in ['people_seed']: + print ' skiping ...' 
+ continue + ref = table.reference + table = client.get_table(ref) + schema = table.schema + rows = table.num_rows + if rows == 0 : + continue + names = [f.name for f in schema] + x = list(set(names) & set([key])) + if x : + full_name = ".".join([dataset,table.table_id]) + r.append({"name":table.table_id,"fields":names,"row_count":rows,"full_name":full_name}) + return r + def get_field_name(self,alias,field_name,index): + """ + This function will format the a field name given an index (the number of times it has occurred in projection) + The index is intended to avoid a "duplicate field" error (bigquery issue) + + @param alias alias of the table + @param field_name name of the field to be formatted + @param index the number of times the field appears in the projection + """ + name = [alias,field_name] + if index > 0 : + return ".".join(name)+" AS :field_name:index".replace(":field_name",field_name).replace(":index",str(index)) + else: + return ".".join(name) + def get_sql(self,**args): + """ + This function will generate that will join a list of tables given a key and a limit of records + @param tables list of tables + @param key key field to be used in the join. The assumption is that the field name is identical across tables (best practice!) + @param limit a limit imposed, in case of ristrictions considering joins are resource intensive + """ + tables = args['tables'] + key = args['key'] + limit = args['limit'] if 'limit' in args else 300000 + limit = str(limit) + SQL = [ + """ + SELECT :fields + FROM + """] + fields = [] + prev_table = None + for table in tables : + name = table['full_name'] #".".join([self.i_dataset,table['name']]) + alias= table['name'] + index = tables.index(table) + sql_ = """ + (select * from :name limit :limit) as :alias + """.replace(":limit",limit) + sql_ = sql_.replace(":name",name).replace(":alias",alias) + fields += [self.get_field_name(alias,field_name,index) for field_name in table['fields'] if field_name != key or (field_name==key and tables.index(table) == 0) ] + if tables.index(table) > 0 : + join = """ + INNER JOIN :sql ON :alias.:field = :prev_alias.:field + """.replace(":name",name) + join = join.replace(":alias",alias).replace(":field",key).replace(":prev_alias",prev_alias) + sql_ = join.replace(":sql",sql_) + # sql_ = " ".join([sql_,join]) + SQL += [sql_] + if index == 0: + prev_alias = str(alias) + + return " ".join(SQL).replace(":fields"," , ".join(fields)) + +class risk : + """ + This class will handle the creation of an SQL query that computes marketer and prosecutor risk (for now) + """ + def __init__(self): + pass + def get_sql(self,**args) : + """ + This function returns the SQL Query that will compute marketer and prosecutor risk + @param key key fields (patient identifier) + @param table table that is subject of the computation + """ + key = args['key'] + table = args['table'] + fields = list(set(table['fields']) - set([key])) + #-- We need to select n-fields max 64 + k = len(fields) + n = np.random.randint(2,24) #-- how many random fields are we processing + ii = np.random.choice(k,n,replace=False) + fields = list(np.array(fields)[ii]) + + sql = """ + SELECT COUNT(g_size) as group_count, SUM(g_size) as patient_count, COUNT(g_size)/SUM(g_size) as marketer, 1/ MIN(g_size) as prosecutor + FROM ( + SELECT COUNT(*) as g_size,:key,:fields + FROM :full_name + GROUP BY :key,:fields + ) + """.replace(":fields", ",".join(fields)).replace(":full_name",table['full_name']).replace(":key",key).replace(":n",str(n)) + return sql + + + + + +if 'action' in 
SYS_ARGS and SYS_ARGS['action'] in ['create','compute'] : + + path = SYS_ARGS['path'] + client = bq.Client.from_service_account_json(path) + i_dataset = SYS_ARGS['i_dataset'] + key = SYS_ARGS['key'] + + mytools = utils(client = client) + tables = mytools.get_tables(dataset=i_dataset,client=client,key=key) + # print len(tables) + # tables = tables[:6] + + if SYS_ARGS['action'] == 'create' : + #usage: + # create --i_dataset --key --o_dataset --table [--file] --path + # + create_sql = mytools.get_sql(tables=tables,key=key) #-- The create statement + o_dataset = SYS_ARGS['o_dataset'] + table = SYS_ARGS['table'] + if 'file' in SYS_ARGS : + f = open(table+'.sql','w') + f.write(create_sql) + f.close() + else: + job = bq.QueryJobConfig() + job.destination = client.dataset(o_dataset).table(table) + job.use_query_cache = True + job.allow_large_results = True + job.priority = 'BATCH' + job.time_partitioning = bq.table.TimePartitioning(type_=bq.table.TimePartitioningType.DAY) + + r = client.query(create_sql,location='US',job_config=job) + + print [r.job_id,' ** ',r.state] + else: + # + # + tables = [tab for tab in tables if tab['name'] == SYS_ARGS['table'] ] + if tables : + risk = risk() + df = pd.DataFrame() + for i in range(0,10) : + sql = risk.get_sql(key=SYS_ARGS['key'],table=tables[0]) + df = df.append(pd.read_gbq(query=sql,private_key=path,dialect='standard')) + df.to_csv(SYS_ARGS['table']+'.csv') + print [i,' ** ',df.shape[0]] + time.sleep(2) + + pass +else: + print 'ERROR' + pass + +# r = risk(path='/home/steve/dev/google-cloud-sdk/accounts/vumc-test.json', i_dataset='raw',o_dataset='risk_o',o_table='mo') +# tables = r.get_tables('raw','person_id') +# sql = r.get_sql(tables=tables[:3],key='person_id') +# # +# # let's post this to a designated location +# # +# f = open('foo.sql','w') +# f.write(sql) +# f.close() +# r.get_sql(tables=tables,key='person_id') +# p = r.compute() +# print p +# p.to_csv("risk.csv") +# r.write('foo.sql') \ No newline at end of file
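
Note on the risk measures computed by src/risk.py: the marketer and prosecutor values that the generated SQL asks BigQuery to produce can be sanity-checked locally once a joined extract is in memory. The sketch below is illustrative only and is not part of this patch; the DataFrame, the column names, and the helper name are made up. Marketer risk is the number of equivalence classes (groups of records sharing the same quasi-identifier values) divided by the number of records, and prosecutor risk is the reciprocal of the smallest class size, mirroring COUNT(g_size)/SUM(g_size) and 1/MIN(g_size) in the SQL above.

import pandas as pd

def marketer_prosecutor_risk(df, quasi_identifiers):
    # equivalence classes: records sharing the same values on the chosen quasi-identifiers
    g_size = df.groupby(quasi_identifiers).size()
    group_count = g_size.shape[0]        # number of equivalence classes
    patient_count = int(g_size.sum())    # number of records
    return {
        "group_count": group_count,
        "patient_count": patient_count,
        "marketer": group_count / float(patient_count),  # average re-identification risk
        "prosecutor": 1.0 / g_size.min()                  # worst case: smallest class
    }

if __name__ == '__main__':
    # toy example: 3 classes over 5 records -> marketer = 3/5 = 0.6, prosecutor = 1/1 = 1.0
    df = pd.DataFrame({"race": ["A", "A", "B", "B", "B"],
                       "year_of_birth": [1970, 1970, 1980, 1981, 1981]})
    print(marketer_prosecutor_risk(df, ["race", "year_of_birth"]))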