# File: airflow/operators/sql_branch_operator.py
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from distutils.util import strtobool

from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.models import BaseOperator, SkipMixin
from airflow.utils.decorators import apply_defaults

# Connection types whose hook is expected to expose ``get_first``.
ALLOWED_CONN_TYPE = {
    "google_cloud_platform",
    "jdbc",
    "mssql",
    "mysql",
    "odbc",
    "oracle",
    "postgres",
    "presto",
    "sqlite",
    "vertica",
}


class BranchSqlOperator(BaseOperator, SkipMixin):
    """
    Runs a SQL query and follows one of two branches depending on its result.

    The first column of the first row returned by the query is interpreted
    as a Boolean; every downstream task that is not part of the selected
    branch is skipped.

    :param sql: the sql code to be executed. (templated)
    :type sql: Can receive a str representing a sql statement or reference to a template file.
               Template reference are recognized by str ending in '.sql'.
               Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
               or string (true/y/yes/1/on/false/n/no/0/off).
    :param follow_task_ids_if_true: task id or task ids to follow if query return true
    :type follow_task_ids_if_true: str or list
    :param follow_task_ids_if_false: task id or task ids to follow if query return false
    :type follow_task_ids_if_false: str or list
    :param conn_id: reference to a specific database
    :type conn_id: str
    :param database: name of database which overwrite defined one in connection
    :type database: str
    :param parameters: (optional) the parameters to render the SQL query with.
    :type parameters: mapping or iterable
    """

    template_fields = ("sql",)
    template_ext = (".sql",)
    ui_color = "#a22034"
    ui_fgcolor = "#F7F7F7"

    @apply_defaults
    def __init__(
            self,
            sql,
            follow_task_ids_if_true,
            follow_task_ids_if_false,
            conn_id="default_conn_id",
            database=None,
            parameters=None,
            *args,
            **kwargs):
        super(BranchSqlOperator, self).__init__(*args, **kwargs)
        self.conn_id = conn_id
        self.sql = sql
        self.parameters = parameters
        self.follow_task_ids_if_true = follow_task_ids_if_true
        self.follow_task_ids_if_false = follow_task_ids_if_false
        self.database = database
        # Hook instance, created on first use by ``_get_hook``.
        self._hook = None

    def _get_hook(self):
        """
        Return the database hook for ``conn_id``, creating it on first use.

        :raises AirflowException: if the connection type is not listed in
            ``ALLOWED_CONN_TYPE`` or the connection does not provide a hook.
        """
        self.log.debug("Get connection for %s", self.conn_id)
        conn = BaseHook.get_connection(self.conn_id)

        if conn.conn_type not in ALLOWED_CONN_TYPE:
            # Sorted so that the message is stable between runs.
            raise AirflowException(
                "The connection type is not supported by BranchSqlOperator. "
                + "Supported connection types: {}".format(sorted(ALLOWED_CONN_TYPE))
            )

        if self._hook is None:
            hook = conn.get_hook()
            # Fail here, before the hook is dereferenced to set the schema.
            if hook is None:
                raise AirflowException(
                    "Failed to establish connection to '%s'" % self.conn_id
                )
            if self.database:
                hook.schema = self.database
            self._hook = hook

        return self._hook

    def _validate_arguments(self):
        """
        Check that the query and both branch targets were supplied.

        :raises AirflowException: if any of them is ``None``.
        """
        if self.sql is None:
            raise AirflowException("Expected 'sql' parameter is missing.")

        if self.follow_task_ids_if_true is None:
            raise AirflowException(
                "Expected 'follow_task_ids_if_true' parameter is missing."
            )

        if self.follow_task_ids_if_false is None:
            raise AirflowException(
                "Expected 'follow_task_ids_if_false' parameter is missing."
            )

    @staticmethod
    def _extract_query_result(record):
        """
        Return the first value of ``record`` as handed back by ``get_first``.

        Accepts a scalar, a row (list or tuple) or a row nested in a list or
        tuple.  Falsy scalars such as ``0`` and ``False`` are valid results
        and must not be mistaken for an empty result set.

        :raises AirflowException: if ``record`` is ``None`` or an empty row.
        """
        no_rows_message = (
            "No rows returned from sql query. Operator expected True or False return value."
        )

        if record is None:
            raise AirflowException(no_rows_message)

        if not isinstance(record, (list, tuple)):
            return record

        if not record:
            raise AirflowException(no_rows_message)

        first = record[0]
        if isinstance(first, (list, tuple)):
            if not first:
                raise AirflowException(no_rows_message)
            return first[0]
        return first

    @staticmethod
    def _to_bool(query_result):
        """
        Convert a Boolean, integer or string query result to ``bool``.

        :raises AirflowException: if the value has another type or is a
            string that ``strtobool`` does not recognise.
        """
        unexpected_message = "Unexpected query return result '%s' type '%s'" % (
            query_result,
            type(query_result),
        )

        # ``bool`` is a subclass of ``int``; both convert the same way.
        if isinstance(query_result, (bool, int)):
            return bool(query_result)

        if isinstance(query_result, str):
            try:
                return bool(strtobool(query_result))
            except ValueError:
                raise AirflowException(unexpected_message)

        raise AirflowException(unexpected_message)

    def execute(self, context):
        """
        Run the query and skip every downstream task outside the chosen branch.

        :param context: task execution context; ``context["ti"]`` is used.
        :raises AirflowException: on missing arguments, an unsupported
            connection, an empty result set or an uninterpretable result.
        """
        # Cheap local checks first, so that a missing argument is reported
        # as such instead of being masked by a connection error.
        self._validate_arguments()

        hook = self._get_hook()

        self.log.info(
            "Executing: %s (with parameters %s) with connection: %s",
            self.sql,
            self.parameters,
            hook,
        )
        record = hook.get_first(self.sql, self.parameters)
        query_result = self._extract_query_result(record)

        self.log.info("Query returns %s, type '%s'", query_result, type(query_result))

        if self._to_bool(query_result):
            follow_branch = self.follow_task_ids_if_true
        else:
            follow_branch = self.follow_task_ids_if_false

        self.skip_all_except(context["ti"], follow_branch)


# File: tests/operators/test_sql_branch_operator.py
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import datetime
import unittest
from tests.compat import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.sql_branch_operator import BranchSqlOperator
from airflow.utils import timezone
from airflow.utils.db import create_session
from airflow.utils.state import State
from tests.hooks.test_hive_hook import TestHiveEnvironment

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
INTERVAL = datetime.timedelta(hours=12)

# Query results the operator is expected to treat as "true".
SUPPORTED_TRUE_VALUES = [
    ["True"],
    ["true"],
    ["1"],
    ["on"],
    [1],
    True,
    "true",
    "1",
    "on",
    1,
]
# Query results the operator is expected to treat as "false".
SUPPORTED_FALSE_VALUES = [
    ["False"],
    ["false"],
    ["0"],
    ["off"],
    [0],
    False,
    "false",
    "0",
    "off",
    0,
]


class TestSqlBranch(TestHiveEnvironment, unittest.TestCase):
    """
    Test for SQL Branch Operator
    """

    @staticmethod
    def _clear_db():
        """Delete every DagRun and TaskInstance row."""
        with create_session() as session:
            session.query(DagRun).delete()
            session.query(TI).delete()

    @classmethod
    def setUpClass(cls):
        super(TestSqlBranch, cls).setUpClass()
        cls._clear_db()

    def setUp(self):
        super(TestSqlBranch, self).setUp()
        self.dag = DAG(
            "sql_branch_operator_test",
            default_args={"owner": "airflow", "start_date": DEFAULT_DATE},
            schedule_interval=INTERVAL,
        )
        self.branch_1 = DummyOperator(task_id="branch_1", dag=self.dag)
        self.branch_2 = DummyOperator(task_id="branch_2", dag=self.dag)
        self.branch_3 = None

    def tearDown(self):
        super(TestSqlBranch, self).tearDown()
        self._clear_db()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _make_branch_op(
            self,
            conn_id="mysql_default",
            sql="SELECT 1",
            follow_task_ids_if_true="branch_1",
            follow_task_ids_if_false="branch_2"):
        """Build the ``make_choice`` BranchSqlOperator attached to ``self.dag``."""
        return BranchSqlOperator(
            task_id="make_choice",
            conn_id=conn_id,
            sql=sql,
            follow_task_ids_if_true=follow_task_ids_if_true,
            follow_task_ids_if_false=follow_task_ids_if_false,
            dag=self.dag,
        )

    def _create_dagrun(self):
        """Clear the DAG and create a running DagRun for ``DEFAULT_DATE``."""
        self.dag.clear()
        return self.dag.create_dagrun(
            run_id="manual__",
            start_date=timezone.utcnow(),
            execution_date=DEFAULT_DATE,
            state=State.RUNNING,
        )

    @staticmethod
    def _mock_get_first(mock_hook):
        """Make the patched BaseHook return a "mysql" connection; return its ``get_first`` mock."""
        mock_hook.get_connection("mysql_default").conn_type = "mysql"
        return mock_hook.get_connection.return_value.get_hook.return_value.get_first

    def _assert_task_states(self, dag_run, expected_states):
        """Assert each task instance of ``dag_run`` has the state in ``expected_states``."""
        for ti in dag_run.get_task_instances():
            if ti.task_id not in expected_states:
                raise ValueError("Invalid task id {task_id} found!".format(task_id=ti.task_id))
            self.assertEqual(ti.state, expected_states[ti.task_id])

    def _run_and_assert(self, branch_op, dag_run, get_first, values, expected_states):
        """Run ``branch_op`` once per mocked query result and check the task states."""
        # NOTE(review): after the first iteration the task instance is already
        # in a finished state and ``run()`` is called without
        # ``ignore_ti_state=True``, so later iterations may not execute the
        # operator again -- confirm that every value is really exercised.
        for value in values:
            get_first.return_value = value
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
            self._assert_task_states(dag_run, expected_states)

    # ------------------------------------------------------------------
    # Error handling
    # ------------------------------------------------------------------

    def test_unsupported_conn_type(self):
        """ Check if BranchSqlOperator throws an exception for unsupported connection type """
        op = self._make_branch_op(
            conn_id="redis_default",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        )

        with self.assertRaises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    def test_invalid_conn(self):
        """ Check if BranchSqlOperator throws an exception for invalid connection """
        op = self._make_branch_op(
            conn_id="invalid_connection",
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
        )

        with self.assertRaises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_invalid_follow_task_true(self, mock_hook):
        """ Check if BranchSqlOperator throws an exception for missing follow_task_ids_if_true """
        # A valid (mocked) connection is used so that the missing branch
        # target, not the connection lookup, is what makes the run fail.
        self._mock_get_first(mock_hook)
        op = self._make_branch_op(
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_true=None,
        )

        with self.assertRaises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_invalid_follow_task_false(self, mock_hook):
        """ Check if BranchSqlOperator throws an exception for missing follow_task_ids_if_false """
        # See test_invalid_follow_task_true for why the connection is mocked.
        self._mock_get_first(mock_hook)
        op = self._make_branch_op(
            sql="SELECT count(1) FROM INFORMATION_SCHEMA.TABLES",
            follow_task_ids_if_false=None,
        )

        with self.assertRaises(AirflowException):
            op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

    # ------------------------------------------------------------------
    # Real backends
    # ------------------------------------------------------------------

    @pytest.mark.backend("mysql")
    def test_sql_branch_operator_mysql(self):
        """ Check if BranchSqlOperator works with the MySQL backend """
        branch_op = self._make_branch_op(conn_id="mysql_default")
        branch_op.run(
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
        )

    @pytest.mark.backend("postgres")
    def test_sql_branch_operator_postgres(self):
        """ Check if BranchSqlOperator works with the Postgres backend """
        branch_op = self._make_branch_op(conn_id="postgres_default")
        branch_op.run(
            start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True
        )

    # ------------------------------------------------------------------
    # Branching with a mocked hook
    # ------------------------------------------------------------------

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_branch_single_value_with_dag_run(self, mock_hook):
        """ Check BranchSqlOperator follows the true branch for a scalar result """
        branch_op = self._make_branch_op()
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=[1],
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.SKIPPED,
            },
        )

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_branch_true_with_dag_run(self, mock_hook):
        """ Check BranchSqlOperator follows the true branch for supported true values """
        branch_op = self._make_branch_op()
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=SUPPORTED_TRUE_VALUES,
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.SKIPPED,
            },
        )

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_branch_false_with_dag_run(self, mock_hook):
        """ Check BranchSqlOperator follows the false branch for supported false values """
        branch_op = self._make_branch_op()
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=SUPPORTED_FALSE_VALUES,
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.SKIPPED,
                "branch_2": State.NONE,
            },
        )

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_branch_list_with_dag_run(self, mock_hook):
        """ Checks if the BranchSqlOperator supports branching off to a list of tasks."""
        branch_op = self._make_branch_op(
            follow_task_ids_if_true=["branch_1", "branch_2"],
            follow_task_ids_if_false="branch_3",
        )
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self.branch_3 = DummyOperator(task_id="branch_3", dag=self.dag)
        self.branch_3.set_upstream(branch_op)
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=[[["1"]]],
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
                "branch_3": State.SKIPPED,
            },
        )

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_invalid_query_result_with_dag_run(self, mock_hook):
        """ Check BranchSqlOperator throws an exception for an uninterpretable query result """
        branch_op = self._make_branch_op()
        self.branch_1.set_upstream(branch_op)
        self.branch_2.set_upstream(branch_op)
        self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)
        get_first.return_value = ["Invalid Value"]

        with self.assertRaises(AirflowException):
            branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_with_skip_in_branch_downstream_dependencies(self, mock_hook):
        """ Test the true branch when branch_2 is downstream of both the operator and branch_1 """
        branch_op = self._make_branch_op()
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=[[true_value] for true_value in SUPPORTED_TRUE_VALUES],
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.NONE,
                "branch_2": State.NONE,
            },
        )

    @mock.patch("airflow.operators.sql_branch_operator.BaseHook")
    def test_with_skip_in_branch_downstream_dependencies2(self, mock_hook):
        """ Test skipping downstream dependency for false condition"""
        branch_op = self._make_branch_op()
        branch_op >> self.branch_1 >> self.branch_2
        branch_op >> self.branch_2
        dr = self._create_dagrun()

        get_first = self._mock_get_first(mock_hook)

        self._run_and_assert(
            branch_op,
            dr,
            get_first,
            values=[[false_value] for false_value in SUPPORTED_FALSE_VALUES],
            expected_states={
                "make_choice": State.SUCCESS,
                "branch_1": State.SKIPPED,
                "branch_2": State.NONE,
            },
        )