From 105ff00224244ec1a0737d0120c70ff91441885a Mon Sep 17 00:00:00 2001
From: Steve Nyemba
Date: Mon, 7 Mar 2022 18:50:29 -0600
Subject: [PATCH] bugfix: ETL multiprocessing

---
 bin/transport         | 34 +++++++++------
 transport/__init__.py |  2 +-
 transport/sql.py      | 97 +++++++++++++++----------------------------
 3 files changed, 57 insertions(+), 76 deletions(-)

diff --git a/bin/transport b/bin/transport
index 1df4c03..47979db 100755
--- a/bin/transport
+++ b/bin/transport
@@ -75,10 +75,10 @@ class Post(Process):
         _info = {"values":self.rows} if 'couch' in self.PROVIDER else self.rows
         ltypes = self.rows.dtypes.values
         columns = self.rows.dtypes.index.tolist()
-        if not self.writer.has() :
+        # if not self.writer.has() :

-            self.writer.make(fields=columns)
+        #     self.writer.make(fields=columns)
         # self.log(module='write',action='make-table',input={"name":self.writer.table})
         for name in columns :
             if _info[name].dtype in ['int32','int64','int','float','float32','float64'] :
@@ -86,7 +86,7 @@ class Post(Process):
             else:
                 value = ''
             _info[name] = _info[name].fillna(value)
-        print (_info)
+
         self.writer.write(_info)
         self.writer.close()

@@ -94,6 +94,7 @@ class Post(Process):
 class ETL (Process):
     def __init__(self,**_args):
         super().__init__()
+        self.name = _args['id']
         if 'provider' not in _args['source'] :
             #@deprecate
@@ -133,18 +134,24 @@ class ETL (Process):
             self.log(module='write',action='partitioning')
-            rows = np.array_split(np.arange(idf.shape[0]),self.JOB_COUNT)
+            rows = np.array_split(np.arange(0,idf.shape[0]),self.JOB_COUNT)
+            #
+            # @TODO: locks
-            for i in rows :
-                _id = 'segment #'.join([str(rows.index(i)),self.name])
-                segment = idf.loc[i,:] #.to_dict(orient='records')
+            for i in np.arange(self.JOB_COUNT) :
+                print ()
+                print (i)
+                _id = ' '.join(['segment #',str(i),self.name])
+                indexes = rows[i]
+                segment = idf.loc[indexes,:].copy() #.to_dict(orient='records')
                 proc = Post(target = self._oargs,rows = segment,name=_id)
                 self.jobs.append(proc)
                 proc.start()
-            self.log(module='write',action='working ...',name=self.name)
-
+            self.log(module='write',action='working',segment=_id)
+            # while poc :
+            #     proc = [job for job in proc if job.is_alive()]
+            #     time.sleep(1)
         except Exception as e:
             print (e)
@@ -168,13 +175,16 @@ if __name__ == '__main__' :
             if 'source' in SYS_ARGS :
                 _config['source'] = {"type":"disk.DiskReader","args":{"path":SYS_ARGS['source'],"delimiter":","}}
-            _config['jobs'] = 10 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
+            _config['jobs'] = 3 if 'jobs' not in SYS_ARGS else int(SYS_ARGS['jobs'])
             etl = ETL (**_config)
-            if not index :
+            if index is None:
                 etl.start()
                 procs.append(etl)
-            if index and _info.index(_config) == index :
+
+            elif _info.index(_config) == index :
+
+                # print (_config)
                 procs = [etl]
                 etl.start()
                 break
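In isolation, the partitioning scheme the hunks above settle on reduces to the following minimal sketch. The Post body and the sample frame are illustrative stand-ins; only JOB_COUNT, the np.array_split() call, the per-segment .copy(), and the one-Process-per-segment layout mirror the patch:

import numpy as np
import pandas as pd
from multiprocessing import Process

class Post(Process):
    # Stand-in for bin/transport's Post: consumes one DataFrame segment.
    def __init__(self, rows, name):
        super().__init__(name=name)
        self.rows = rows
    def run(self):
        # The real worker hands self.rows to a writer; a print stands in here.
        print(self.name, ':', self.rows.shape[0], 'rows')

if __name__ == '__main__':
    idf = pd.DataFrame({'x': range(10)})   # illustrative payload
    JOB_COUNT = 3
    # np.array_split (unlike np.split) tolerates segments of unequal length
    rows = np.array_split(np.arange(0, idf.shape[0]), JOB_COUNT)
    jobs = []
    for i in np.arange(JOB_COUNT):
        # .copy() gives each worker an independent slice of idf
        segment = idf.loc[rows[i], :].copy()
        proc = Post(rows=segment, name=' '.join(['segment #', str(i)]))
        jobs.append(proc)
        proc.start()
    for job in jobs:
        job.join()   # the hunk's "@TODO: locks" comment gestures at this kind of join step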
diff --git a/transport/__init__.py b/transport/__init__.py
index ce5090b..6642c4e 100644
--- a/transport/__init__.py
+++ b/transport/__init__.py
@@ -162,7 +162,7 @@ def instance(**_args):
         if provider not in ['mongodb','couchdb','bigquery'] :
             uri = ''.join([provider,"://",account,host,'/',database])
-            e = sqlalchemy.create_engine (uri)
+            e = sqlalchemy.create_engine (uri,future=True)
             args['sqlalchemy'] = e
         #
         # @TODO: Include handling of bigquery with SQLAlchemy
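The one-line change above opts the engine into SQLAlchemy's 2.0-style behavior, which create_engine() exposes through the future flag as of SQLAlchemy 1.4. A minimal sketch, with a purely illustrative connection URI:

import sqlalchemy

# future=True selects 2.0-style engine/connection semantics on SQLAlchemy 1.4.x
engine = sqlalchemy.create_engine('postgresql://user@localhost:5432/analytics', future=True)
with engine.connect() as conn:
    # 2.0-style usage: textual statements go through sqlalchemy.text()
    print(conn.execute(sqlalchemy.text('SELECT 1')).scalar())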
diff --git a/transport/sql.py b/transport/sql.py
index d408942..6d44976 100644
--- a/transport/sql.py
+++ b/transport/sql.py
@@ -21,7 +21,7 @@ else:
 import json
 from google.oauth2 import service_account
 from google.cloud import bigquery as bq
-from multiprocessing import Lock
+from multiprocessing import Lock, RLock
 import pandas as pd
 import numpy as np
 import nzpy as nz   #--- netezza drivers
@@ -30,7 +30,7 @@ import os

 class SQLRW :
-
+    lock = RLock()
     DRIVERS  = {"postgresql":pg,"redshift":pg,"mysql":my,"mariadb":my,"netezza":nz}
     REFERENCE = {
         "netezza":{"port":5480,"handler":nz,"dtype":"VARCHAR(512)"},
@@ -71,7 +71,7 @@ class SQLRW :
             # _handler = SQLWriter.REFERENCE[_provider]['handler']
         _handler = _args['driver'] #-- handler to the driver
         self._dtype = _args['default']['type'] if 'default' in _args and 'type' in _args['default'] else 'VARCHAR(256)'
-        self._provider = _args['provider']
+        # self._provider = _args['provider']
         # self._dtype = SQLWriter.REFERENCE[_provider]['dtype'] if 'dtype' not in _args else _args['dtype']
         # self._provider = _provider
         if _handler == nz :
@@ -173,7 +173,7 @@ class SQLWriter(SQLRW,Writer):
         # In the advent that data typing is difficult to determine we can inspect and perform a default case
         # This slows down the process but improves reliability of the data
         # NOTE: Proper data type should be set on the target system if their source is unclear.
-        self._inspect = False if 'inspect' not in _args else _args['inspect']
+        self._cast = False if 'cast' not in _args else _args['cast']

     def init(self,fields=None):
@@ -244,78 +244,49 @@ class SQLWriter(SQLRW,Writer):
         #
         #
         # We are assuming 2 cases i.e dict or pd.DataFrame
         # info = [info] if type(info) == dict else info.values.tolist()
-        cursor = self.conn.cursor()
+
         try:
             table = self._tablename(self.table)
             _sql = "INSERT INTO :table (:fields) VALUES (:values)".replace(":table",table) #.replace(":table",self.table).replace(":fields",_fields)
-            if self._inspect :
-                for _row in info :
-                    fields = list(_row.keys())
-                    if self._cast == False :
-                        values = ",".join(_row.values())
-                    else:
-                        # values = "'"+"','".join([str(value) for value in _row.values()])+"'"
-                        values = [",".join(["%(",name,")s"]) for name in _row.keys()]

-                    # values = [ "".join(["'",str(_row[key]),"'"]) if np.nan(_row[key]).isnumeric() else str(_row[key]) for key in _row]
-                    # print (values)
-                    query = _sql.replace(":fields",",".join(fields)).replace(":values",values)
-                    if type(info) == pd.DataFrame :
-                        _values = info.values.tolist()
-                    elif type(info) == list and type(info[0]) == dict:
-                        print ('........')
-                        _values = [tuple(item.values()) for item in info]
-                    else:
-                        _values = info;
-                    cursor.execute(query,_values)
-
-
-                pass
+
+            if type(info) == list :
+                _info = pd.DataFrame(info)
+            elif type(info) == dict :
+                _info = pd.DataFrame([info])
             else:
-
-                # _sql = _sql.replace(":fields",_fields)
-                # _sql = _sql.replace(":values",",".join(["%("+name+")s" for name in self.fields]))
-                # _sql = _sql.replace("(:fields)","")
-
-                # _sql = _sql.replace(":values",values)
-                # if type(info) == pd.DataFrame :
-                #     _info = info[self.fields].values.tolist()
-
-                # elif type(info) == dict :
-                #     _info = info.values()
-                # else:
-                #     # _info = []
-
-                #     _info = pd.DataFrame(info)[self.fields].values.tolist()
-                #     _info = pd.DataFrame(info).to_dict(orient='records')
-                if type(info) == list :
-                    _info = pd.DataFrame(info)
-                elif type(info) == dict :
-                    _info = pd.DataFrame([info])
-                else:
-                    _info = pd.DataFrame(info)
+                _info = pd.DataFrame(info)
+
+            if _info.shape[0] == 0 :
-                if self._engine :
-                    # pd.to_sql(_info,self._engine)
-                    rows = _info.to_sql(table,self._engine,schema=self.schema,if_exists='append',index=False)
-
+                return
+            SQLRW.lock.acquire()
+            if self._engine is not None:
+                # pd.to_sql(_info,self._engine)
+                if self.schema in ['',None] :
+                    rows = _info.to_sql(table,self._engine,if_exists='append',index=False)
                 else:
-                    _fields = ",".join(self.fields)
-                    _sql = _sql.replace(":fields",_fields)
-                    values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
-                    _sql = _sql.replace(":values",values)
-
-                    cursor.executemany(_sql,_info.values.tolist())
-                # cursor.commit()
+                    rows = _info.to_sql(self.table,self._engine,schema=self.schema,if_exists='append',index=False)
+
+            else:
+                _fields = ",".join(self.fields)
+                _sql = _sql.replace(":fields",_fields)
+                values = ", ".join("?"*len(self.fields)) if self._provider == 'netezza' else ",".join(["%s" for name in self.fields])
+                _sql = _sql.replace(":values",values)
+                cursor = self.conn.cursor()
+                cursor.executemany(_sql,_info.values.tolist())
+                cursor.close()
+                # cursor.commit()
             # self.conn.commit()
         except Exception as e:
             print(e)
             pass
         finally:
-            self.conn.commit()
+
+            if self._engine is None :
+                self.conn.commit()
+            SQLRW.lock.release()
             # cursor.close()
         pass
     def close(self):
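For reference, the locking discipline the rewritten write() relies on reduces to the sketch below. The Writer class, the sqlite target and the sample frame are stand-ins; what mirrors the patch is the class-level multiprocessing RLock, the empty-frame early return, and the acquire/try/finally/release bracket around the insert. Note the lock is only shared across workers when children are forked from the process that created it (the default start method on Linux):

import pandas as pd
import sqlalchemy
from multiprocessing import Process, RLock

class Writer:
    lock = RLock()   # created at import time, so forked children share one lock

    def __init__(self, uri, table):
        # each worker opens its own engine; only the lock is shared
        self._engine = sqlalchemy.create_engine(uri, future=True)
        self.table = table

    def write(self, _info):
        if _info.shape[0] == 0:
            return                    # nothing to insert
        Writer.lock.acquire()
        try:
            _info.to_sql(self.table, self._engine, if_exists='append', index=False)
        finally:
            Writer.lock.release()     # release in finally, as the patch does

def job(uri, table, frame):
    Writer(uri, table).write(frame)

if __name__ == '__main__':
    frame = pd.DataFrame({'x': [1, 2, 3]})
    procs = [Process(target=job, args=('sqlite:///demo.db', 'logs', frame)) for _ in range(3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()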