Bug fix: sqlalchemy facilities added
This commit is contained in:
parent
a7df7bfbce
commit
14a551e57b
|
@ -26,7 +26,7 @@ import numpy as np
|
|||
import json
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
import sqlalchemy
|
||||
if sys.version_info[0] > 2 :
|
||||
from transport.common import Reader, Writer #, factory
|
||||
from transport import disk
|
||||
|
@ -59,8 +59,8 @@ class factory :
|
|||
"postgresql":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
|
||||
"redshift":{"port":5432,"host":"localhost","database":os.environ['USER'],"driver":pg,"default":{"type":"VARCHAR"}},
|
||||
"bigquery":{"class":{"read":sql.BQReader,"write":sql.BQWriter}},
|
||||
"mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
|
||||
"mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"}},
|
||||
"mysql":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
|
||||
"mariadb":{"port":3306,"host":"localhost","default":{"type":"VARCHAR(256)"},"driver":my},
|
||||
"mongo":{"port":27017,"host":"localhost","class":{"read":mongo.MongoReader,"write":mongo.MongoWriter}},
|
||||
"couch":{"port":5984,"host":"localhost","class":{"read":couch.CouchReader,"write":couch.CouchWriter}},
|
||||
"netezza":{"port":5480,"driver":nz,"default":{"type":"VARCHAR(256)"}}}
|
||||
|
@ -137,7 +137,38 @@ def instance(**_args):
|
|||
pointer = factory.PROVIDERS[provider]['class'][_id]
|
||||
else:
|
||||
pointer = sql.SQLReader if _id == 'read' else sql.SQLWriter
|
||||
|
||||
#
|
||||
# Let us try to establish an sqlalchemy wrapper
|
||||
try:
|
||||
host = ''
|
||||
if provider not in ['bigquery','mongodb','couchdb','sqlite'] :
|
||||
#
|
||||
# In these cases we are assuming RDBMS and thus would exclude NoSQL and BigQuery
|
||||
username = args['username'] if 'username' in args else ''
|
||||
password = args['password'] if 'password' in args else ''
|
||||
if username == '' :
|
||||
account = ''
|
||||
else:
|
||||
account = username + ':'+password+'@'
|
||||
host = args['host']
|
||||
if 'port' in args :
|
||||
host = host+":"+str(args['port'])
|
||||
|
||||
database = args['database']
|
||||
elif provider == 'sqlite':
|
||||
account = ''
|
||||
host = ''
|
||||
database = args['path'] if 'path' in args else args['database']
|
||||
if provider not in ['mongodb','couchdb','bigquery'] :
|
||||
uri = ''.join([provider,"://",account,host,'/',database])
|
||||
|
||||
e = sqlalchemy.create_engine (uri)
|
||||
args['sqlalchemy'] = e
|
||||
#
|
||||
# @TODO: Include handling of bigquery with SQLAlchemy
|
||||
except Exception as e:
|
||||
print (e)
|
||||
|
||||
return pointer(**args)
|
||||
|
||||
return None
|
||||
|
|
135
transport/sql.py
135
transport/sql.py
|
@ -12,6 +12,8 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
|
|||
import psycopg2 as pg
|
||||
import mysql.connector as my
|
||||
import sys
|
||||
|
||||
import sqlalchemy
|
||||
if sys.version_info[0] > 2 :
|
||||
from transport.common import Reader, Writer #, factory
|
||||
else:
|
||||
|
@ -44,7 +46,8 @@ class SQLRW :
|
|||
_info['dbname'] = _args['db'] if 'db' in _args else _args['database']
|
||||
self.table = _args['table'] if 'table' in _args else None
|
||||
self.fields = _args['fields'] if 'fields' in _args else []
|
||||
# _provider = _args['provider']
|
||||
|
||||
self._provider = _args['provider'] if 'provider' in _args else None
|
||||
# _info['host'] = 'localhost' if 'host' not in _args else _args['host']
|
||||
# _info['port'] = SQLWriter.REFERENCE[_provider]['port'] if 'port' not in _args else _args['port']
|
||||
|
||||
|
@ -59,7 +62,7 @@ class SQLRW :
|
|||
if 'username' in _args or 'user' in _args:
|
||||
key = 'username' if 'username' in _args else 'user'
|
||||
_info['user'] = _args[key]
|
||||
_info['password'] = _args['password']
|
||||
_info['password'] = _args['password'] if 'password' in _args else ''
|
||||
#
|
||||
# We need to load the drivers here to see what we are dealing with ...
|
||||
|
||||
|
@ -74,17 +77,29 @@ class SQLRW :
|
|||
_info['database'] = _info['dbname']
|
||||
_info['securityLevel'] = 0
|
||||
del _info['dbname']
|
||||
if _handler == my :
|
||||
_info['database'] = _info['dbname']
|
||||
del _info['dbname']
|
||||
|
||||
self.conn = _handler.connect(**_info)
|
||||
self._engine = _args['sqlalchemy'] if 'sqlalchemy' in _args else None
|
||||
def has(self,**_args):
|
||||
found = False
|
||||
try:
|
||||
table = _args['table']
|
||||
sql = "SELECT * FROM :table LIMIT 1".replace(":table",table)
|
||||
found = pd.read_sql(sql,self.conn).shape[0]
|
||||
if self._engine :
|
||||
_conn = self._engine.connect()
|
||||
else:
|
||||
_conn = self.conn
|
||||
found = pd.read_sql(sql,_conn).shape[0]
|
||||
found = True
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
finally:
|
||||
if self._engine :
|
||||
_conn.close()
|
||||
return found
|
||||
def isready(self):
|
||||
_sql = "SELECT * FROM :table LIMIT 1".replace(":table",self.table)
|
||||
|
@ -104,7 +119,8 @@ class SQLRW :
|
|||
try:
|
||||
if "select" in _sql.lower() :
|
||||
cursor.close()
|
||||
return pd.read_sql(_sql,self.conn)
|
||||
_conn = self._engine.connect() if self._engine else self.conn
|
||||
return pd.read_sql(_sql,_conn)
|
||||
else:
|
||||
# Executing a command i.e no expected return values ...
|
||||
cursor.execute(_sql)
|
||||
|
@ -122,7 +138,8 @@ class SQLRW :
|
|||
pass
|
||||
class SQLReader(SQLRW,Reader) :
|
||||
def __init__(self,**_args):
|
||||
super().__init__(**_args)
|
||||
super().__init__(**_args)
|
||||
|
||||
def read(self,**_args):
|
||||
if 'sql' in _args :
|
||||
_sql = (_args['sql'])
|
||||
|
@ -151,27 +168,47 @@ class SQLWriter(SQLRW,Writer):
|
|||
# NOTE: Proper data type should be set on the target system if their source is unclear.
|
||||
self._inspect = False if 'inspect' not in _args else _args['inspect']
|
||||
self._cast = False if 'cast' not in _args else _args['cast']
|
||||
|
||||
def init(self,fields=None):
|
||||
if not fields :
|
||||
try:
|
||||
self.fields = pd.read_sql("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
|
||||
self.fields = pd.read_sql_query("SELECT * FROM :table LIMIT 1".replace(":table",self.table),self.conn).columns.tolist()
|
||||
finally:
|
||||
pass
|
||||
else:
|
||||
self.fields = fields;
|
||||
|
||||
def make(self,fields):
|
||||
self.fields = fields
|
||||
|
||||
sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
|
||||
def make(self,**_args):
|
||||
|
||||
if 'fields' in _args :
|
||||
fields = _args['fields']
|
||||
sql = " ".join(["CREATE TABLE",self.table," (", ",".join([ name +' '+ self._dtype for name in fields]),")"])
|
||||
else:
|
||||
schema = _args['schema']
|
||||
N = len(schema)
|
||||
_map = _args['map'] if 'map' in _args else {}
|
||||
sql = [] # ["CREATE TABLE ",_args['table'],"("]
|
||||
for _item in schema :
|
||||
_type = _item['type']
|
||||
if _type in _map :
|
||||
_type = _map[_type]
|
||||
sql = sql + [" " .join([_item['name'], ' ',_type])]
|
||||
sql = ",".join(sql)
|
||||
sql = ["CREATE TABLE ",_args['table'],"( ",sql," )"]
|
||||
sql = " ".join(sql)
|
||||
# sql = " ".join(["CREATE TABLE",_args['table']," (", ",".join([ schema[i]['name'] +' '+ (schema[i]['type'] if schema[i]['type'] not in _map else _map[schema[i]['type'] ]) for i in range(0,N)]),")"])
|
||||
cursor = self.conn.cursor()
|
||||
try:
|
||||
|
||||
cursor.execute(sql)
|
||||
except Exception as e :
|
||||
print (e)
|
||||
print (sql)
|
||||
pass
|
||||
finally:
|
||||
cursor.close()
|
||||
# cursor.close()
|
||||
self.conn.commit()
|
||||
pass
|
||||
def write(self,info):
|
||||
"""
|
||||
:param info writes a list of data to a given set of fields
|
||||
|
@ -184,7 +221,7 @@ class SQLWriter(SQLRW,Writer):
|
|||
elif type(info) == dict :
|
||||
_fields = info.keys()
|
||||
elif type(info) == pd.DataFrame :
|
||||
_fields = info.columns
|
||||
_fields = info.columns.tolist()
|
||||
|
||||
# _fields = info.keys() if type(info) == dict else info[0].keys()
|
||||
_fields = list (_fields)
|
||||
|
@ -192,12 +229,13 @@ class SQLWriter(SQLRW,Writer):
|
|||
#
|
||||
# @TODO: Use pandas/odbc ? Not sure b/c it requires sqlalchemy
|
||||
#
|
||||
if type(info) != list :
|
||||
#
|
||||
# We are assuming 2 cases i.e dict or pd.DataFrame
|
||||
info = [info] if type(info) == dict else info.values.tolist()
|
||||
# if type(info) != list :
|
||||
# #
|
||||
# # We are assuming 2 cases i.e dict or pd.DataFrame
|
||||
# info = [info] if type(info) == dict else info.values.tolist()
|
||||
cursor = self.conn.cursor()
|
||||
try:
|
||||
|
||||
_sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",self.table) #.replace(":table",self.table).replace(":fields",_fields)
|
||||
if self._inspect :
|
||||
for _row in info :
|
||||
|
@ -223,34 +261,49 @@ class SQLWriter(SQLRW,Writer):
|
|||
|
||||
pass
|
||||
else:
|
||||
_fields = ",".join(self.fields)
|
||||
|
||||
# _sql = _sql.replace(":fields",_fields)
|
||||
# _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
|
||||
# _sql = _sql.replace("(:fields)","")
|
||||
_sql = _sql.replace(":fields",_fields)
|
||||
values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
|
||||
_sql = _sql.replace(":values",values)
|
||||
if type(info) == pd.DataFrame :
|
||||
_info = info[self.fields].values.tolist()
|
||||
elif type(info) == dict :
|
||||
_info = info.values()
|
||||
else:
|
||||
# _info = []
|
||||
|
||||
# _sql = _sql.replace(":values",values)
|
||||
# if type(info) == pd.DataFrame :
|
||||
# _info = info[self.fields].values.tolist()
|
||||
|
||||
# elif type(info) == dict :
|
||||
# _info = info.values()
|
||||
# else:
|
||||
# # _info = []
|
||||
|
||||
_info = pd.DataFrame(info)[self.fields].values.tolist()
|
||||
# for row in info :
|
||||
|
||||
# if type(row) == dict :
|
||||
# _info.append( list(row.values()))
|
||||
cursor.executemany(_sql,_info)
|
||||
# _info = pd.DataFrame(info)[self.fields].values.tolist()
|
||||
# _info = pd.DataFrame(info).to_dict(orient='records')
|
||||
if type(info) == list :
|
||||
_info = pd.DataFrame(info)
|
||||
elif type(info) == dict :
|
||||
_info = pd.DataFrame([info])
|
||||
else:
|
||||
_info = pd.DataFrame(info)
|
||||
|
||||
|
||||
if self._engine :
|
||||
# pd.to_sql(_info,self._engine)
|
||||
_info.to_sql(self.table,self._engine,if_exists='append',index=False)
|
||||
else:
|
||||
_fields = ",".join(self.fields)
|
||||
_sql = _sql.replace(":fields",_fields)
|
||||
values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
|
||||
_sql = _sql.replace(":values",values)
|
||||
|
||||
cursor.executemany(_sql,_info.values.tolist())
|
||||
# cursor.commit()
|
||||
|
||||
# self.conn.commit()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
finally:
|
||||
self.conn.commit()
|
||||
cursor.close()
|
||||
self.conn.commit()
|
||||
# cursor.close()
|
||||
pass
|
||||
def close(self):
|
||||
try:
|
||||
|
@ -265,6 +318,7 @@ class BigQuery:
|
|||
self.path = path
|
||||
self.dtypes = _args['dtypes'] if 'dtypes' in _args else None
|
||||
self.table = _args['table'] if 'table' in _args else None
|
||||
self.client = bq.Client.from_service_account_json(self.path)
|
||||
def meta(self,**_args):
|
||||
"""
|
||||
This function returns meta data for a given table or query with dataset/table properly formatted
|
||||
|
@ -272,16 +326,16 @@ class BigQuery:
|
|||
:param sql sql query to be pulled,
|
||||
"""
|
||||
table = _args['table']
|
||||
client = bq.Client.from_service_account_json(self.path)
|
||||
ref = client.dataset(self.dataset).table(table)
|
||||
return client.get_table(ref).schema
|
||||
|
||||
ref = self.client.dataset(self.dataset).table(table)
|
||||
return self.client.get_table(ref).schema
|
||||
def has(self,**_args):
|
||||
found = False
|
||||
try:
|
||||
found = self.meta(**_args) is not None
|
||||
except Exception as e:
|
||||
pass
|
||||
return found
|
||||
return found
|
||||
class BQReader(BigQuery,Reader) :
|
||||
def __init__(self,**_args):
|
||||
|
||||
|
@ -304,8 +358,9 @@ class BQReader(BigQuery,Reader) :
|
|||
if (':dataset' in SQL or ':DATASET' in SQL) and self.dataset:
|
||||
SQL = SQL.replace(':dataset',self.dataset).replace(':DATASET',self.dataset)
|
||||
_info = {'credentials':self.credentials,'dialect':'standard'}
|
||||
return pd.read_gbq(SQL,**_info) if SQL else None
|
||||
# return pd.read_gbq(SQL,credentials=self.credentials,dialect='standard') if SQL else None
|
||||
return pd.read_gbq(SQL,**_info) if SQL else None
|
||||
# return self.client.query(SQL).to_dataframe() if SQL else None
|
||||
|
||||
|
||||
class BQWriter(BigQuery,Writer):
|
||||
lock = Lock()
|
||||
|
|
Loading…
Reference in New Issue