Skip to content

Commit

Permalink
[AIRFLOW-5500] Fix the trigger_dag api in the case of nested subdags
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Bournhonesque <[email protected]>
(cherry picked from commit 16e06f8)
  • Loading branch information
cBournhonesque authored and kaxil committed Jun 28, 2020
1 parent 00e9cbf commit 78fd2cb
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
12 changes: 4 additions & 8 deletions airflow/api/common/experimental/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,17 @@ def _trigger_dag(
else:
run_conf = json.loads(conf)

triggers = list()
dags_to_trigger = list()
dags_to_trigger.append(dag)
while dags_to_trigger:
dag = dags_to_trigger.pop()
trigger = dag.create_dagrun(
triggers = []
dags_to_trigger = [dag] + dag.subdags
for _dag in dags_to_trigger:
trigger = _dag.create_dagrun(
run_id=run_id,
execution_date=execution_date,
state=State.RUNNING,
conf=run_conf,
external_trigger=True,
)
triggers.append(trigger)
if dag.subdags:
dags_to_trigger.extend(dag.subdags)
return triggers


Expand Down
25 changes: 25 additions & 0 deletions tests/api/common/experimental/test_trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, dag_mock)

self.assertEqual(3, len(triggers))

@mock.patch('airflow.models.DAG')
@mock.patch('airflow.models.DagRun')
@mock.patch('airflow.models.DagBag')
def test_trigger_dag_include_nested_subdags(self, dag_bag_mock, dag_run_mock, dag_mock):
dag_id = "trigger_dag"
dag_bag_mock.dags = [dag_id]
dag_bag_mock.get_dag.return_value = dag_mock
dag_run_mock.find.return_value = None
dag1 = mock.MagicMock()
dag1.subdags = []
dag2 = mock.MagicMock()
dag2.subdags = [dag1]
dag_mock.subdags = [dag1, dag2]

triggers = _trigger_dag(
dag_id,
dag_bag_mock,
dag_run_mock,
run_id=None,
conf=None,
execution_date=None,
replace_microseconds=True)

self.assertEqual(3, len(triggers))

@mock.patch('airflow.models.DagBag')
def test_trigger_dag_with_str_conf(self, dag_bag_mock):
dag_id = "trigger_dag_with_str_conf"
Expand Down

0 comments on commit 78fd2cb

Please sign in to comment.