diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 057ec1aa98c0a..f07919777cfae 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -17,13 +17,26 @@ # under the License. from datetime import datetime -from typing import List, Optional +from typing import Dict, List, Optional, TypeVar import cx_Oracle import numpy from airflow.hooks.dbapi import DbApiHook +PARAM_TYPES = {bool, float, int, str} + +ParameterType = TypeVar('ParameterType', Dict, List, None) + + +def _map_param(value): + if value in PARAM_TYPES: + # In this branch, value is a Python type; calling it produces + # an instance of the type which is understood by the Oracle driver + # in the out parameter mapping mechanism. + value = value() + return value + class OracleHook(DbApiHook): """ @@ -266,3 +279,55 @@ def bulk_insert_rows( self.log.info('[%s] inserted %s rows', table, row_count) cursor.close() conn.close() # type: ignore[attr-defined] + + def callproc( + self, + identifier: str, + autocommit: bool = False, + parameters: ParameterType = None, + ) -> ParameterType: + """ + Call the stored procedure identified by the provided string. + + Any 'OUT parameters' must be provided with a value of either the + expected Python type (e.g., `int`) or an instance of that type. + + The return value is a list or mapping that includes parameters in + both directions; the actual return type depends on the type of the + provided `parameters` argument. + + See + https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var + for further reference. + """ + if parameters is None: + parameters = () + + args = ",".join( + f":{name}" + for name in (parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1)) + ) + + sql = f"BEGIN {identifier}({args}); END;" + + def handler(cursor): + if isinstance(cursor.bindvars, list): + return [v.getvalue() for v in cursor.bindvars] + + if isinstance(cursor.bindvars, dict): + return {n: v.getvalue() for (n, v) in cursor.bindvars.items()} + + raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}") + + result = self.run( + sql, + autocommit=autocommit, + parameters=( + {name: _map_param(value) for (name, value) in parameters.items()} + if isinstance(parameters, dict) + else [_map_param(value) for value in parameters] + ), + handler=handler, + ) + + return result diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py index dcc07a207fbaa..b80d570b7dc95 100644 --- a/airflow/providers/oracle/operators/oracle.py +++ b/airflow/providers/oracle/operators/oracle.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Iterable, List, Mapping, Optional, Union +from typing import Dict, Iterable, List, Mapping, Optional, Union from airflow.models import BaseOperator from airflow.providers.oracle.hooks.oracle import OracleHook @@ -62,4 +62,40 @@ def __init__( def execute(self, context) -> None: self.log.info('Executing: %s', self.sql) hook = OracleHook(oracle_conn_id=self.oracle_conn_id) - hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + if self.sql: + hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters) + + +class OracleStoredProcedureOperator(BaseOperator): + """ + Executes stored procedure in a specific Oracle database. + + :param procedure: name of stored procedure to call (templated) + :type procedure: str + :param oracle_conn_id: The :ref:`Oracle connection id ` + reference to a specific Oracle database. + :type oracle_conn_id: str + :param parameters: (optional) the parameters provided in the call + :type parameters: dict or iterable + """ + + template_fields = ('procedure',) + ui_color = '#ededed' + + def __init__( + self, + *, + procedure: str, + oracle_conn_id: str = 'oracle_default', + parameters: Optional[Union[Dict, List]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.oracle_conn_id = oracle_conn_id + self.procedure = procedure + self.parameters = parameters + + def execute(self, context) -> None: + self.log.info('Executing: %s', self.procedure) + hook = OracleHook(oracle_conn_id=self.oracle_conn_id) + return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters) diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 0101a34b83c07..0f5c7dfa96e96 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -291,3 +291,41 @@ def test_bulk_insert_rows_no_rows(self): rows = [] with pytest.raises(ValueError): self.db_hook.bulk_insert_rows('table', rows) + + def test_callproc_dict(self): + parameters = {"a": 1, "b": 2, "c": 3} + + class bindvar(int): + def getvalue(self): + return self + + self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} + result = self.db_hook.callproc('proc', True, parameters) + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] + assert result == parameters + + def test_callproc_list(self): + parameters = [1, 2, 3] + + class bindvar(int): + def getvalue(self): + return self + + self.cur.bindvars = list(map(bindvar, parameters)) + result = self.db_hook.callproc('proc', True, parameters) + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] + assert result == parameters + + def test_callproc_out_param(self): + parameters = [1, int, float, bool, str] + + def bindvar(value): + m = mock.Mock() + m.getvalue.return_value = value + return m + + self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters] + result = self.db_hook.callproc('proc', True, parameters) + expected = [1, 0, 0.0, False, ''] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] + assert result == expected diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index 8565efe6aea46..40359f694fa70 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -19,7 +19,7 @@ from unittest import mock from airflow.providers.oracle.hooks.oracle import OracleHook -from airflow.providers.oracle.operators.oracle import OracleOperator +from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator class TestOracleOperator(unittest.TestCase): @@ -46,3 +46,29 @@ def test_execute(self, mock_run): autocommit=autocommit, parameters=parameters, ) + + +class TestOracleStoredProcedureOperator(unittest.TestCase): + @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run) + def test_execute(self, mock_run): + procedure = 'test' + oracle_conn_id = 'oracle_default' + parameters = {'parameter': 'value'} + context = "test_context" + task_id = "test_task_id" + + operator = OracleStoredProcedureOperator( + procedure=procedure, + oracle_conn_id=oracle_conn_id, + parameters=parameters, + task_id=task_id, + ) + result = operator.execute(context=context) + assert result is mock_run.return_value + mock_run.assert_called_once_with( + mock.ANY, + 'BEGIN test(:parameter); END;', + autocommit=True, + parameters=parameters, + handler=mock.ANY, + )