Bug fix: sqlalchemy facilities added

This commit is contained in:
Steve Nyemba 2022-03-03 16:08:24 -06:00
parent a7df7bfbce
commit 14a551e57b
2 changed files with 130 additions and 44 deletions

View File

@ -26,7 +26,7 @@ import numpy as np
import json
import importlib
import sys
import sqlalchemy
if sys.version_info[0] > 2 :
from transport.common import Reader, Writer #, factory
from transport import disk
@ -59,8 +59,8 @@ class factory :
"postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
"redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
"bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
"mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
"mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
"mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
"mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
"mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},
"couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},
"netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
@ -137,6 +137,37 @@ def instance(**_args):
pointer = factory.PROVIDERS[provider]['class'][_id]
else:
pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
#
# Let us try to establish an sqlalchemy wrapper
try:
host = ''
if provider not in ['bigquery','mongodb','couchdb','sqlite'] :
#
# In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery
username = args['username'] if 'username' in args else ''
password = args['password'] if 'password' in args else ''
if username == '' :
account = ''
else:
account = username + ':'+password+'@'
host = args['host']
if 'port' in args :
host = host+":"+str(args['port'])
database = args['database']
elif provider == 'sqlite':
account = ''
host = ''
database = args['path'] if 'path' in args else args['database']
if provider not in ['mongodb','couchdb','bigquery'] :
uri = ''.join([provider,"://",account,host,'/',database])
e = sqlalchemy.create_engine (uri)
args['sqlalchemy'] = e
#
# @TODO: Include handling of bigquery with SQLAlchemy
except Exception as e:
print (e)
return pointer(**args)

View File

@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
import psycopg2 as pg
import mysql.connector as my
import sys
import sqlalchemy
if sys.version_info[0] > 2 :
from transport.common import Reader, Writer #, factory
else:
@ -44,7 +46,8 @@ class SQLRW :
_info['dbname'] = _args['db'] if 'db' in _args else _args['database']
self.table = _args['table'] if 'table' in _args else None
self.fields = _args['fields'] if 'fields' in _args else []
# _provider = _args['provider']
self._provider = _args['provider'] if 'provider' in _args else None
# _info['host'] = 'localhost' if 'host' not in _args else _args['host']
# _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
@ -59,7 +62,7 @@ class SQLRW :
if 'username' in _args or 'user' in _args:
key = 'username' if 'username' in _args else 'user'
_info['user'] = _args[key]
_info['password'] = _args['password']
_info['password'] = _args['password'] if 'password' in _args else ''
#
# We need to load the drivers here to see what we are dealing with ...
@ -74,17 +77,29 @@ class SQLRW :
_info['database'] = _info['dbname']
_info['securityLevel'] = 0
del _info['dbname']
if _handler == my :
_info['database'] = _info['dbname']
del _info['dbname']
self.conn = _handler.connect(**_info)
self._engine = _args['sqlalchemy'] if 'sqlalchemy' in _args else None
def has(self,**_args):
found = False
try:
table = _args['table']
sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
found = pd.read_sql(sql,self.conn).shape[0]
if self._engine :
_conn = self._engine.connect()
else:
_conn = self.conn
found = pd.read_sql(sql,_conn).shape[0]
found = True
except Exception as e:
pass
finally:
if self._engine :
_conn.close()
return found
def isready(self):
_sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
@ -104,7 +119,8 @@ class SQLRW :
try:
if "select" in _sql.lower() :
cursor.close()
return pd.read_sql(_sql,self.conn)
_conn = self._engine.connect() if self._engine else self.conn
return pd.read_sql(_sql,_conn)
else:
# Executing a command i.e no expected return values ...
cursor.execute(_sql)
@ -123,6 +139,7 @@ class SQLRW :
class SQLReader(SQLRW,Reader) :
def __init__(self,**_args):
super().__init__(**_args)
def read(self,**_args):
if 'sql' in _args :
_sql = (_args['sql'])
@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
# NOTE: Proper data type should be set on the target system if their source is unclear.
self._inspect = False if 'inspect' not in _args else _args['inspect']
self._cast = False if 'cast' not in _args else _args['cast']
def init(self,fields=None):
if not fields :
try:
self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
finally:
pass
else:
self.fields = fields;
def make(self,fields):
self.fields = fields
def make(self,**_args):
if 'fields' in _args :
fields = _args['fields']
sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
else:
schema = _args['schema']
N = len(schema)
_map = _args['map'] if 'map' in _args else {}
sql = [] # ["CREATE TABLE ",_args['table'],"("]
for _item in schema :
_type = _item['type']
if _type in _map :
_type = _map[_type]
sql = sql + [" " .join([_item['name'], ' ',_type])]
sql = ",".join(sql)
sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
sql = " ".join(sql)
# sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
cursor = self.conn.cursor()
try:
cursor.execute(sql)
except Exception as e :
print (e)
print (sql)
pass
finally:
cursor.close()
# cursor.close()
self.conn.commit()
pass
def write(self,info):
"""
:param info writes a list of data to a given set of fields
@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
elif type(info) == dict :
_fields = info.keys()
elif type(info) == pd.DataFrame :
_fields = info.columns
_fields = info.columns.tolist()
# _fields = info.keys() if type(info) == dict else info[0].keys()
_fields = list (_fields)
@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
#
# @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
#
if type(info) != list :
#
# We are assuming 2 cases i.e dict or pd.DataFrame
info = [info] if type(info) == dict else info.values.tolist()
# if type(info) != list :
# #
# # We are assuming 2 cases i.e dict or pd.DataFrame
# info = [info] if type(info) == dict else info.values.tolist()
cursor = self.conn.cursor()
try:
_sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
if self._inspect :
for _row in info :
@ -223,26 +261,41 @@ class SQLWriter(SQLRW,Writer):
pass
else:
_fields = ",".join(self.fields)
# _sql = _sql.replace(":fields",_fields)
# _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
# _sql = _sql.replace("(:fields)","")
# _sql = _sql.replace(":values",values)
# if type(info) == pd.DataFrame :
# _info = info[self.fields].values.tolist()
# elif type(info) == dict :
# _info = info.values()
# else:
# # _info = []
# _info = pd.DataFrame(info)[self.fields].values.tolist()
# _info = pd.DataFrame(info).to_dict(orient='records')
if type(info) == list :
_info = pd.DataFrame(info)
elif type(info) == dict :
_info = pd.DataFrame([info])
else:
_info = pd.DataFrame(info)
if self._engine :
# pd.to_sql(_info,self._engine)
_info.to_sql(self.table,self._engine,if_exists='append',index=False)
else:
_fields = ",".join(self.fields)
_sql = _sql.replace(":fields",_fields)
values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
_sql = _sql.replace(":values",values)
if type(info) == pd.DataFrame :
_info = info[self.fields].values.tolist()
elif type(info) == dict :
_info = info.values()
else:
# _info = []
_info = pd.DataFrame(info)[self.fields].values.tolist()
# for row in info :
# if type(row) == dict :
# _info.append( list(row.values()))
cursor.executemany(_sql,_info)
cursor.executemany(_sql,_info.values.tolist())
# cursor.commit()
# self.conn.commit()
except Exception as e:
@ -250,7 +303,7 @@ class SQLWriter(SQLRW,Writer):
pass
finally:
self.conn.commit()
cursor.close()
# cursor.close()
pass
def close(self):
try:
@ -265,6 +318,7 @@ class BigQuery:
self.path = path
self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
self.table = _args['table'] if 'table' in _args else None
self.client = bq.Client.from_service_account_json(self.path)
def meta(self,**_args):
"""
This function returns meta data for a given table or query with dataset/table properly formatted
@ -272,9 +326,9 @@ class BigQuery:
:param sql sql query to be pulled,
"""
table = _args['table']
client = bq.Client.from_service_account_json(self.path)
ref = client.dataset(self.dataset).table(table)
return client.get_table(ref).schema
ref = self.client.dataset(self.dataset).table(table)
return self.client.get_table(ref).schema
def has(self,**_args):
found = False
try:
@ -305,7 +359,8 @@ class BQReader(BigQuery,Reader) :
SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
_info = {'credentials':self.credentials,'dialect':'standard'}
return pd.read_gbq(SQL,**_info) if SQL else None
# return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
# return self.client.query(SQL).to_dataframe() if SQL else None
class BQWriter(BigQuery,Writer):
lock = Lock()