Skip to content

Commit

Permalink
[AIRFLOW-4013] Fix Mark Success/Failed picking all execution_date bug (
Browse files Browse the repository at this point in the history
  • Loading branch information
yuqian90 authored and Jing Guo committed Sep 2, 2019
1 parent 3f80803 commit 271c184
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
5 changes: 5 additions & 0 deletions airflow/api/common/experimental/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ def get_execution_dates(dag, execution_date, future, past):
start_date = execution_date if not past else start_date
if dag.schedule_interval == '@once':
dates = [start_date]
elif not dag.schedule_interval:
# If schedule_interval is None, need to look at existing DagRun if the user wants future or
# past runs.
dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
dates = sorted({d.execution_date for d in dag_runs})
else:
dates = dag.date_range(start_date=start_date, end_date=end_date)
return dates
Expand Down
20 changes: 20 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,26 @@ def get_dagrun(self, execution_date, session=None):

return dagrun

@provide_session
def get_dagruns_between(self, start_date, end_date, session=None):
"""
Returns the list of dag runs between start_date (inclusive) and end_date (inclusive).
:param start_date: The starting execution date of the DagRun to find.
:param end_date: The ending execution date of the DagRun to find.
:param session:
:return: The list of DagRuns found.
"""
dagruns = (
session.query(DagRun)
.filter(
DagRun.dag_id == self.dag_id,
DagRun.execution_date >= start_date,
DagRun.execution_date <= end_date)
.all())

return dagruns

@provide_session
def _get_latest_execution_date(self, session=None):
return session.query(func.max(DagRun.execution_date)).filter(
Expand Down
50 changes: 49 additions & 1 deletion tests/api/common/experimental/test_mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import unittest
import time
from datetime import datetime
from datetime import datetime, timedelta

from airflow import configuration, models
from airflow.api.common.experimental.mark_tasks import (
Expand All @@ -44,7 +44,12 @@ def setUpClass(cls):
cls.dag1.sync_to_db()
cls.dag2 = dagbag.dags['example_subdag_operator']
cls.dag2.sync_to_db()
cls.dag3 = dagbag.dags['example_trigger_target_dag']
cls.dag3.sync_to_db()
cls.execution_dates = [days_ago(2), days_ago(1)]
start_date3 = cls.dag3.default_args["start_date"]
cls.dag3_execution_dates = [start_date3, start_date3 + timedelta(days=1),
start_date3 + timedelta(days=2)]

def setUp(self):
clear_db_runs()
Expand All @@ -64,6 +69,14 @@ def setUp(self):
dr.dag = self.dag2
dr.verify_integrity()

drs = _create_dagruns(self.dag3,
self.dag3_execution_dates,
state=State.SUCCESS,
run_id_template="manual__{}")
for dr in drs:
dr.dag = self.dag3
dr.verify_integrity()

def tearDown(self):
clear_db_runs()

Expand Down Expand Up @@ -140,6 +153,23 @@ def test_mark_tasks_now(self):
self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
State.SUCCESS, snapshot)

# set one task as FAILED. dag3 has schedule_interval None
snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
upstream=False, downstream=False, future=False,
past=False, state=State.FAILED, commit=True)
# exactly one TaskInstance should have been altered
self.assertEqual(len(altered), 1)
# task should have been marked as failed
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[1]],
State.FAILED, snapshot)
# tasks on other days should be unchanged
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]],
None, snapshot)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]],
None, snapshot)

def test_mark_downstream(self):
# test downstream
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
Expand Down Expand Up @@ -179,6 +209,15 @@ def test_mark_tasks_future(self):
self.assertEqual(len(altered), 2)
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)

snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
upstream=False, downstream=False, future=True,
past=False, state=State.FAILED, commit=True)
self.assertEqual(len(altered), 2)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[0]], None, snapshot)
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[1:], State.FAILED, snapshot)

def test_mark_tasks_past(self):
# set one task to success towards end of scheduled dag runs
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
Expand All @@ -189,6 +228,15 @@ def test_mark_tasks_past(self):
self.assertEqual(len(altered), 2)
self.verify_state(self.dag1, [task.task_id], self.execution_dates, State.SUCCESS, snapshot)

snapshot = TestMarkTasks.snapshot_state(self.dag3, self.dag3_execution_dates)
task = self.dag3.get_task("run_this")
altered = set_state(tasks=[task], execution_date=self.dag3_execution_dates[1],
upstream=False, downstream=False, future=False,
past=True, state=State.FAILED, commit=True)
self.assertEqual(len(altered), 2)
self.verify_state(self.dag3, [task.task_id], self.dag3_execution_dates[:2], State.FAILED, snapshot)
self.verify_state(self.dag3, [task.task_id], [self.dag3_execution_dates[2]], None, snapshot)

def test_mark_tasks_multiple(self):
# set multiple tasks to success
snapshot = TestMarkTasks.snapshot_state(self.dag1, self.execution_dates)
Expand Down

0 comments on commit 271c184

Please sign in to comment.