Bug fix: sqlalchemy facilities added

Steve Nyemba 2022-03-03 16:08:24 -06:00
parent a7df7bfbce
commit 14a551e57b
2 changed files with 130 additions and 44 deletions

Changed file 1 of 2

@@ -26,7 +26,7 @@ import numpy as np
 import json
 import importlib
 import sys
+import sqlalchemy
 if sys.version_info[0] > 2 :
     from transport.common import Reader, Writer #, factory
     from transport import disk
@@ -59,8 +59,8 @@ class factory :
     "postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
     "redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
     "bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
-    "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
-    "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
+    "mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
+    "mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
     "mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},
     "couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},
     "netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
@@ -137,6 +137,37 @@ def instance(**_args):
             pointer = factory.PROVIDERS[provider]['class'][_id]
         else:
             pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
+        #
+        # Let us try to establish an sqlalchemy wrapper
+        try:
+            host = ''
+            if provider not in ['bigquery','mongodb','couchdb','sqlite'] :
+                #
+                # In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery
+                username = args['username'] if 'username' in args else ''
+                password = args['password'] if 'password' in args else ''
+                if username == '' :
+                    account = ''
+                else:
+                    account = username + ':'+password+'@'
+                host = args['host']
+                if 'port' in args :
+                    host = host+":"+str(args['port'])
+                database = args['database']
+            elif provider == 'sqlite':
+                account = ''
+                host = ''
+                database = args['path'] if 'path' in args else args['database']
+            if provider not in ['mongodb','couchdb','bigquery'] :
+                uri = ''.join([provider,"://",account,host,'/',database])
+                e = sqlalchemy.create_engine (uri)
+                args['sqlalchemy'] = e
+            #
+            # @TODO: Include handling of bigquery with SQLAlchemy
+        except Exception as e:
+            print (e)
         return pointer(**args)
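
To make the new wrapper concrete, here is a rough, standalone sketch (not part of the commit) of the URI that instance() now assembles before calling sqlalchemy.create_engine; the provider, credentials and host values below are invented for illustration:

    import sqlalchemy

    # hypothetical connection settings, mirroring the keys the diff reads from args
    args = {'provider':'postgresql','username':'demo','password':'secret',
            'host':'localhost','port':5432,'database':'sample'}
    account = args['username'] + ':' + args['password'] + '@'
    host = args['host'] + ':' + str(args['port'])
    uri = ''.join([args['provider'], '://', account, host, '/', args['database']])
    # uri == 'postgresql://demo:secret@localhost:5432/sample'
    engine = sqlalchemy.create_engine(uri)   # the engine is then handed to the reader/writer as args['sqlalchemy']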

Changed file 2 of 2

@@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
 import psycopg2 as pg
 import mysql.connector as my
 import sys
+import sqlalchemy
 if sys.version_info[0] > 2 :
     from transport.common import Reader, Writer #, factory
 else:
@@ -44,7 +46,8 @@ class SQLRW :
         _info['dbname'] = _args['db'] if 'db' in _args else _args['database']
         self.table = _args['table'] if 'table' in _args else None
         self.fields = _args['fields'] if 'fields' in _args else []
+        # _provider = _args['provider']
+        self._provider = _args['provider'] if 'provider' in _args else None
         # _info['host'] = 'localhost' if 'host' not in _args else _args['host']
         # _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
@@ -59,7 +62,7 @@ class SQLRW :
         if 'username' in _args or 'user' in _args:
             key = 'username' if 'username' in _args else 'user'
             _info['user'] = _args[key]
-            _info['password'] = _args['password']
+            _info['password'] = _args['password'] if 'password' in _args else ''
         #
         # We need to load the drivers here to see what we are dealing with ...
@@ -74,17 +77,29 @@ class SQLRW :
             _info['database'] = _info['dbname']
             _info['securityLevel'] = 0
             del _info['dbname']
+        if _handler == my :
+            _info['database'] = _info['dbname']
+            del _info['dbname']
         self.conn = _handler.connect(**_info)
+        self._engine = _args['sqlalchemy'] if 'sqlalchemy' in _args else None
     def has(self,**_args):
         found = False
         try:
             table = _args['table']
             sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
-            found = pd.read_sql(sql,self.conn).shape[0]
+            if self._engine :
+                _conn = self._engine.connect()
+            else:
+                _conn = self.conn
+            found = pd.read_sql(sql,_conn).shape[0]
             found = True
         except Exception as e:
             pass
+        finally:
+            if self._engine :
+                _conn.close()
         return found
     def isready(self):
         _sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
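
The pattern introduced above (use the injected SQLAlchemy engine when one exists, otherwise fall back to the raw DB-API connection) recurs throughout the rest of the file. A minimal, hedged sketch of the idea outside the class, with invented names (row_exists, engine, raw_conn):

    import pandas as pd

    def row_exists(table, engine=None, raw_conn=None):
        # prefer an SQLAlchemy engine when one was injected, else use the DB-API connection
        conn = engine.connect() if engine else raw_conn
        try:
            pd.read_sql("SELECT * FROM " + table + " LIMIT 1", conn)
            return True
        except Exception:
            return False
        finally:
            if engine:
                conn.close()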
@@ -104,7 +119,8 @@ class SQLRW :
         try:
             if "select" in _sql.lower() :
                 cursor.close()
-                return pd.read_sql(_sql,self.conn)
+                _conn = self._engine.connect() if self._engine else self.conn
+                return pd.read_sql(_sql,_conn)
             else:
                 # Executing a command i.e no expected return values ...
                 cursor.execute(_sql)
@@ -123,6 +139,7 @@ class SQLRW :
 class SQLReader(SQLRW,Reader) :
     def __init__(self,**_args):
         super().__init__(**_args)
     def read(self,**_args):
         if 'sql' in _args :
             _sql = (_args['sql'])
@@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
         # NOTE: Proper data type should be set on the target system if their source is unclear.
         self._inspect = False if 'inspect' not in _args else _args['inspect']
         self._cast = False if 'cast' not in _args else _args['cast']
     def init(self,fields=None):
         if not fields :
             try:
-                self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
+                self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
             finally:
                 pass
         else:
             self.fields = fields;
-    def make(self,fields):
-        self.fields = fields
-        sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+    def make(self,**_args):
+        if 'fields' in _args :
+            fields = _args['fields']
+            sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
+        else:
+            schema = _args['schema']
+            N = len(schema)
+            _map = _args['map'] if 'map' in _args else {}
+            sql = [] # ["CREATE TABLE ",_args['table'],"("]
+            for _item in schema :
+                _type = _item['type']
+                if _type in _map :
+                    _type = _map[_type]
+                sql = sql + [" " .join([_item['name'], ' ',_type])]
+            sql = ",".join(sql)
+            sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
+            sql = " ".join(sql)
+            # sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
         cursor = self.conn.cursor()
         try:
             cursor.execute(sql)
         except Exception as e :
             print (e)
+            print (sql)
             pass
         finally:
-            cursor.close()
+            # cursor.close()
+            self.conn.commit()
+            pass
     def write(self,info):
         """
         :param info writes a list of data to a given set of fields
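
The reworked make() can now be driven by a schema list plus an optional type map instead of a flat field list. A rough reconstruction of the CREATE TABLE string it assembles, using invented column names and types (no database is touched here):

    schema = [{'name':'id','type':'INT64'},{'name':'name','type':'STRING'}]   # hypothetical schema
    _map = {'INT64':'INTEGER','STRING':'VARCHAR(256)'}

    columns = []
    for _item in schema:
        _type = _map.get(_item['type'], _item['type'])   # remap the type when a mapping exists
        columns.append(_item['name'] + ' ' + _type)
    sql = ' '.join(['CREATE TABLE', 'users', '(', ','.join(columns), ')'])
    # -> roughly: CREATE TABLE users ( id INTEGER,name VARCHAR(256) )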
@@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
         elif type(info) == dict :
             _fields = info.keys()
         elif type(info) == pd.DataFrame :
-            _fields = info.columns
+            _fields = info.columns.tolist()
         # _fields = info.keys() if type(info) == dict else info[0].keys()
         _fields = list (_fields)
@@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
         #
         # @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
         #
-        if type(info) != list :
-            #
-            # We are assuming 2 cases i.e dict or pd.DataFrame
-            info = [info] if type(info) == dict else info.values.tolist()
+        # if type(info) != list :
+        #     #
+        #     # We are assuming 2 cases i.e dict or pd.DataFrame
+        #     info = [info] if type(info) == dict else info.values.tolist()
         cursor = self.conn.cursor()
         try:
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
             if self._inspect :
                 for _row in info :
@@ -223,26 +261,41 @@ class SQLWriter(SQLRW,Writer):
                     pass
             else:
-                _fields = ",".join(self.fields)
                 # _sql = _sql.replace(":fields",_fields)
                 # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
                 # _sql = _sql.replace("(:fields)","")
-                _sql = _sql.replace(":fields",_fields)
-                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                _sql = _sql.replace(":values",values)
-                if type(info) == pd.DataFrame :
-                    _info = info[self.fields].values.tolist()
-                elif type(info) == dict :
-                    _info = info.values()
-                else:
-                    # _info = []
-                    _info = pd.DataFrame(info)[self.fields].values.tolist()
-                # for row in info :
-                #     if type(row) == dict :
-                #         _info.append( list(row.values()))
-                cursor.executemany(_sql,_info)
+                # _sql = _sql.replace(":values",values)
+                # if type(info) == pd.DataFrame :
+                #     _info = info[self.fields].values.tolist()
+                # elif type(info) == dict :
+                #     _info = info.values()
+                # else:
+                # #     _info = []
+                #     _info = pd.DataFrame(info)[self.fields].values.tolist()
+                # _info = pd.DataFrame(info).to_dict(orient='records')
+                if type(info) == list :
+                    _info = pd.DataFrame(info)
+                elif type(info) == dict :
+                    _info = pd.DataFrame([info])
+                else:
+                    _info = pd.DataFrame(info)
+                if self._engine :
+                    # pd.to_sql(_info,self._engine)
+                    _info.to_sql(self.table,self._engine,if_exists='append',index=False)
+                else:
+                    _fields = ",".join(self.fields)
+                    _sql = _sql.replace(":fields",_fields)
+                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                    _sql = _sql.replace(":values",values)
+                    cursor.executemany(_sql,_info.values.tolist())
+                # cursor.commit()
             # self.conn.commit()
         except Exception as e:
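
In plain terms, write() now normalises the payload to a DataFrame and then takes one of two paths: with an SQLAlchemy engine it appends the frame via to_sql, otherwise it falls back to a parameterised executemany. A hedged, self-contained sketch of the two paths (the table name, engine URL and sample rows below are invented):

    import pandas as pd
    import sqlalchemy

    df = pd.DataFrame([{'id': 1, 'name': 'alice'}, {'id': 2, 'name': 'bob'}])

    # engine path: pandas issues the INSERTs through SQLAlchemy
    engine = sqlalchemy.create_engine('sqlite:///example.db')
    df.to_sql('users', engine, if_exists='append', index=False)

    # fallback path: hand the rows to the DB-API cursor directly
    # cursor.executemany("INSERT INTO users (id,name) VALUES (%s,%s)", df.values.tolist())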
@@ -250,7 +303,7 @@ class SQLWriter(SQLRW,Writer):
             pass
         finally:
             self.conn.commit()
-            cursor.close()
+            # cursor.close()
             pass
     def close(self):
         try:
@@ -265,6 +318,7 @@ class BigQuery:
         self.path = path
         self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
         self.table = _args['table'] if 'table' in _args else None
+        self.client = bq.Client.from_service_account_json(self.path)
     def meta(self,**_args):
         """
         This function returns meta data for a given table or query with dataset/table properly formatted
@@ -272,16 +326,16 @@ class BigQuery:
         :param sql sql query to be pulled,
         """
         table = _args['table']
-        client = bq.Client.from_service_account_json(self.path)
-        ref = client.dataset(self.dataset).table(table)
-        return client.get_table(ref).schema
+        ref = self.client.dataset(self.dataset).table(table)
+        return self.client.get_table(ref).schema
     def has(self,**_args):
         found = False
         try:
             found = self.meta(**_args) is not None
         except Exception as e:
             pass
         return found
 class BQReader(BigQuery,Reader) :
     def __init__(self,**_args):
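
The BigQuery change is simply to build the client once in __init__ and reuse it when fetching table metadata. A small sketch under assumed names (the key-file path, dataset and table are placeholders):

    from google.cloud import bigquery as bq

    client = bq.Client.from_service_account_json('/path/to/key.json')   # built once, reused by meta()/has()
    ref = client.dataset('my_dataset').table('my_table')
    schema = client.get_table(ref).schema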
@@ -305,7 +359,8 @@ class BQReader(BigQuery,Reader) :
         SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
         _info = {'credentials':self.credentials,'dialect':'standard'}
         return pd.read_gbq(SQL,**_info) if SQL else None
-        # return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
+        # return self.client.query(SQL).to_dataframe() if SQL else None
 class BQWriter(BigQuery,Writer):
     lock = Lock()