Skip to content

Commit

Permalink
Add method 'callproc' on Oracle hook (#20072)
Browse files Browse the repository at this point in the history
  • Loading branch information
malthe authored Dec 13, 2021
1 parent 8ac1b41 commit c7f36f2
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 4 deletions.
67 changes: 66 additions & 1 deletion airflow/providers/oracle/hooks/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,26 @@
# under the License.

from datetime import datetime
from typing import List, Optional
from typing import Dict, List, Optional, TypeVar

import cx_Oracle
import numpy

from airflow.hooks.dbapi import DbApiHook

PARAM_TYPES = {bool, float, int, str}

ParameterType = TypeVar('ParameterType', Dict, List, None)


def _map_param(value):
if value in PARAM_TYPES:
# In this branch, value is a Python type; calling it produces
# an instance of the type which is understood by the Oracle driver
# in the out parameter mapping mechanism.
value = value()
return value


class OracleHook(DbApiHook):
"""
Expand Down Expand Up @@ -266,3 +279,55 @@ def bulk_insert_rows(
self.log.info('[%s] inserted %s rows', table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]

def callproc(
self,
identifier: str,
autocommit: bool = False,
parameters: ParameterType = None,
) -> ParameterType:
"""
Call the stored procedure identified by the provided string.
Any 'OUT parameters' must be provided with a value of either the
expected Python type (e.g., `int`) or an instance of that type.
The return value is a list or mapping that includes parameters in
both directions; the actual return type depends on the type of the
provided `parameters` argument.
See
https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
for further reference.
"""
if parameters is None:
parameters = ()

args = ",".join(
f":{name}"
for name in (parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1))
)

sql = f"BEGIN {identifier}({args}); END;"

def handler(cursor):
if isinstance(cursor.bindvars, list):
return [v.getvalue() for v in cursor.bindvars]

if isinstance(cursor.bindvars, dict):
return {n: v.getvalue() for (n, v) in cursor.bindvars.items()}

raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}")

result = self.run(
sql,
autocommit=autocommit,
parameters=(
{name: _map_param(value) for (name, value) in parameters.items()}
if isinstance(parameters, dict)
else [_map_param(value) for value in parameters]
),
handler=handler,
)

return result
40 changes: 38 additions & 2 deletions airflow/providers/oracle/operators/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterable, List, Mapping, Optional, Union
from typing import Dict, Iterable, List, Mapping, Optional, Union

from airflow.models import BaseOperator
from airflow.providers.oracle.hooks.oracle import OracleHook
Expand Down Expand Up @@ -62,4 +62,40 @@ def __init__(
def execute(self, context) -> None:
self.log.info('Executing: %s', self.sql)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
if self.sql:
hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)


class OracleStoredProcedureOperator(BaseOperator):
"""
Executes stored procedure in a specific Oracle database.
:param procedure: name of stored procedure to call (templated)
:type procedure: str
:param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
reference to a specific Oracle database.
:type oracle_conn_id: str
:param parameters: (optional) the parameters provided in the call
:type parameters: dict or iterable
"""

template_fields = ('procedure',)
ui_color = '#ededed'

def __init__(
self,
*,
procedure: str,
oracle_conn_id: str = 'oracle_default',
parameters: Optional[Union[Dict, List]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.oracle_conn_id = oracle_conn_id
self.procedure = procedure
self.parameters = parameters

def execute(self, context) -> None:
self.log.info('Executing: %s', self.procedure)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
38 changes: 38 additions & 0 deletions tests/providers/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,41 @@ def test_bulk_insert_rows_no_rows(self):
rows = []
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows('table', rows)

def test_callproc_dict(self):
parameters = {"a": 1, "b": 2, "c": 3}

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
assert result == parameters

def test_callproc_list(self):
parameters = [1, 2, 3]

class bindvar(int):
def getvalue(self):
return self

self.cur.bindvars = list(map(bindvar, parameters))
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
assert result == parameters

def test_callproc_out_param(self):
parameters = [1, int, float, bool, str]

def bindvar(value):
m = mock.Mock()
m.getvalue.return_value = value
return m

self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
result = self.db_hook.callproc('proc', True, parameters)
expected = [1, 0, 0.0, False, '']
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
assert result == expected
28 changes: 27 additions & 1 deletion tests/providers/oracle/operators/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from unittest import mock

from airflow.providers.oracle.hooks.oracle import OracleHook
from airflow.providers.oracle.operators.oracle import OracleOperator
from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator


class TestOracleOperator(unittest.TestCase):
Expand All @@ -46,3 +46,29 @@ def test_execute(self, mock_run):
autocommit=autocommit,
parameters=parameters,
)


class TestOracleStoredProcedureOperator(unittest.TestCase):
@mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
def test_execute(self, mock_run):
procedure = 'test'
oracle_conn_id = 'oracle_default'
parameters = {'parameter': 'value'}
context = "test_context"
task_id = "test_task_id"

operator = OracleStoredProcedureOperator(
procedure=procedure,
oracle_conn_id=oracle_conn_id,
parameters=parameters,
task_id=task_id,
)
result = operator.execute(context=context)
assert result is mock_run.return_value
mock_run.assert_called_once_with(
mock.ANY,
'BEGIN test(:parameter); END;',
autocommit=True,
parameters=parameters,
handler=mock.ANY,
)

0 comments on commit c7f36f2

Please sign in to comment.