Skip to content

Commit

Permalink
[AIRFLOW-6685] ThresholdCheckOperator (#7353)
Browse files Browse the repository at this point in the history
* [AIRFLOW-6685] Data Quality Check operators

* removed .get_connection to get hook in get_sql_value

* added tests for get_sql_value

* threshold check operator and tests added to checkoperator file
  • Loading branch information
alexzlue authored Mar 30, 2020
1 parent 6018532 commit 4c6ae18
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 1 deletion.
84 changes: 84 additions & 0 deletions airflow/operators/check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,87 @@ def get_db_hook(self):
:rtype: DbApiHook
"""
return BaseHook.get_hook(conn_id=self.conn_id)


class ThresholdCheckOperator(BaseOperator):
"""
Performs a value check using sql code against a mininmum threshold
and a maximum threshold. Thresholds can be in the form of a numeric
value OR a sql statement that results a numeric.
Note that this is an abstract class and get_db_hook
needs to be defined. Whereas a get_db_hook is hook that gets a
single record from an external source.
:param sql: the sql to be executed. (templated)
:type sql: str
:param min_threshold: numerical value or min threshold sql to be executed (templated)
:type min_threshold: numeric or str
:param max_threshold: numerical value or max threshold sql to be executed (templated)
:type max_threshold: numeric or str
"""

template_fields = ('sql', 'min_threshold', 'max_threshold') # type: Iterable[str]
template_ext = ('.hql', '.sql',) # type: Iterable[str]

@apply_defaults
def __init__(
self,
sql: str,
min_threshold: Any,
max_threshold: Any,
conn_id: Optional[str] = None,
*args, **kwargs
):
super().__init__(*args, **kwargs)
self.sql = sql
self.conn_id = conn_id
self.min_threshold = _convert_to_float_if_possible(min_threshold)
self.max_threshold = _convert_to_float_if_possible(max_threshold)

def execute(self, context=None):
hook = self.get_db_hook()
result = hook.get_first(self.sql)[0][0]

if isinstance(self.min_threshold, float):
lower_bound = self.min_threshold
else:
lower_bound = hook.get_first(self.min_threshold)[0][0]

if isinstance(self.max_threshold, float):
upper_bound = self.max_threshold
else:
upper_bound = hook.get_first(self.max_threshold)[0][0]

meta_data = {
"result": result,
"task_id": self.task_id,
"min_threshold": lower_bound,
"max_threshold": upper_bound,
"within_threshold": lower_bound <= result <= upper_bound
}

self.push(meta_data)
if not meta_data["within_threshold"]:
error_msg = (f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
f'Check description: {meta_data.get("description")}\n'
f'SQL: {self.sql}\n'
f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
)
raise AirflowException(error_msg)

self.log.info("Test %s Successful.", self.task_id)

def push(self, meta_data):
"""
Optional: Send data check info and metadata to an external database.
Default functionality will log metadata.
"""

info = "\n".join([f"""{key}: {item}""" for key, item in meta_data.items()])
self.log.info("Log from %s:\n%s", self.dag_id, info)

def get_db_hook(self):
return BaseHook.get_hook(conn_id=self.conn_id)
105 changes: 104 additions & 1 deletion tests/operators/test_check_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@

from airflow.exceptions import AirflowException
from airflow.models import DAG
from airflow.operators.check_operator import CheckOperator, IntervalCheckOperator, ValueCheckOperator
from airflow.operators.check_operator import (
CheckOperator, IntervalCheckOperator, ThresholdCheckOperator, ValueCheckOperator,
)


class TestCheckOperator(unittest.TestCase):
Expand Down Expand Up @@ -219,3 +221,104 @@ def returned_row():

with self.assertRaisesRegex(AirflowException, "f0, f1"):
operator.execute()


class TestThresholdCheckOperator(unittest.TestCase):

def _construct_operator(self, sql, min_threshold, max_threshold):
dag = DAG('test_dag', start_date=datetime(2017, 1, 1))

return ThresholdCheckOperator(
task_id='test_task',
sql=sql,
min_threshold=min_threshold,
max_threshold=max_threshold,
dag=dag
)

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_pass_min_value_max_value(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.return_value = [(10,)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select avg(val) from table1 limit 1',
1,
100
)

operator.execute()

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_fail_min_value_max_value(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.return_value = [(10,)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select avg(val) from table1 limit 1',
20,
100
)

with self.assertRaisesRegex(AirflowException, '10.*20.0.*100.0'):
operator.execute()

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_pass_min_sql_max_sql(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select 10',
'Select 1',
'Select 100'
)

operator.execute()

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_fail_min_sql_max_sql(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select 10',
'Select 20',
'Select 100'
)

with self.assertRaisesRegex(AirflowException, '10.*20.*100'):
operator.execute()

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_pass_min_value_max_sql(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select 75',
45,
'Select 100'
)

operator.execute()

@mock.patch.object(ThresholdCheckOperator, 'get_db_hook')
def test_fail_min_sql_max_value(self, mock_get_db_hook):
mock_hook = mock.Mock()
mock_hook.get_first.side_effect = lambda x: [(int(x.split()[1]),)]
mock_get_db_hook.return_value = mock_hook

operator = self._construct_operator(
'Select 155',
'Select 45',
100
)

with self.assertRaisesRegex(AirflowException, '155.*45.*100.0'):
operator.execute()

0 comments on commit 4c6ae18

Please sign in to comment.