Scheduler to handle incrementing of try_number (apache#39336)
Previously, there was a lot of bad stuff happening around try_number.

We incremented it when the task started running. Because of that, we had getter logic to return "_try_number + 1" when the task was not running. But this gave the "right" try number before the task ran, and the wrong number after it ran. And since it was naively incremented whenever the task started running -- i.e. without regard to why it was running -- we had to decrement it when deferring or exiting on a reschedule.

What I do here is try to remove all of that stuff:

- no more private _try_number attr
- no more getter logic
- no more decrementing
- no more incrementing as part of task execution
Now what we do is increment only when the task is set to scheduled and only when it's not coming out of deferral or "up_for_reschedule". So the try_number will be more stable. It will not change throughout the course of task execution. The only time it will be incremented is when there's legitimately a new try.
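
As a rough sketch (illustrative helper only, not the actual scheduler code), the rule is:

from airflow.utils.state import TaskInstanceState

# States a task resumes from, where scheduling it again is not a new try.
RESUMING_STATES = {TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE}


def move_to_scheduled(ti):
    # Hypothetical helper: bump try_number only for a legitimately new try.
    if ti.state not in RESUMING_STATES:
        ti.try_number += 1
    ti.state = TaskInstanceState.SCHEDULED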

One consequence of this is that try_number will no longer be incremented if you run either airflow tasks run or ti.run() in isolation. But because Airflow assumes that all task runs are scheduled by the scheduler, I do not regard this as a breaking change.

If user code or provider code has implemented hacks to get the "right" try_number when looking at it at the wrong time (because previously it gave the wrong answer), unfortunately that code will just have to be patched. There are only two cases I know of in the providers codebase -- openlineage listener, and dbt openlineage.
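
For illustration only (a hypothetical completion hook, not the real openlineage listener code), the patch typically amounts to dropping the compensation:

def handle_task_success(ti):
    # Before this change, hooks sometimes had to undo the runtime bump,
    # e.g. attempt = ti.try_number - 1 after the task finished.
    # Now try_number is stable for the whole run, so use it directly.
    attempt = ti.try_number
    return attempt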

As a courtesy for backcompat we also add a property _try_number, which is just a proxy for try_number, so you'll still be able to access this attr. But it will not behave the same as it did before.
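
Roughly (a sketch, not the exact code in this commit), the shim is just a property proxy:

class TaskInstance:  # illustrative stand-in for the real ORM model
    try_number: int = 0

    @property
    def _try_number(self) -> int:
        return self.try_number

    @_try_number.setter
    def _try_number(self, value: int) -> None:
        self.try_number = value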

---------

Co-authored-by: Jed Cunningham <[email protected]>
2 people authored and RodrigoGanancia committed May 10, 2024
1 parent 8043844 commit 84f63b4
Showing 40 changed files with 517 additions and 376 deletions.
4 changes: 0 additions & 4 deletions airflow/api/common/mark_tasks.py
@@ -158,10 +158,6 @@ def set_state(
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += session.scalars(qry_sub_dag.with_for_update()).all()
for task_instance in tis_altered:
# The try_number was decremented when setting to up_for_reschedule and deferred.
# Increment it back when changing the state again
if task_instance.state in (TaskInstanceState.DEFERRED, TaskInstanceState.UP_FOR_RESCHEDULE):
task_instance._try_number += 1
task_instance.set_state(state, session=session)
session.flush()
else:
2 changes: 1 addition & 1 deletion airflow/api_connexion/schemas/task_instance_schema.py
@@ -53,7 +53,7 @@ class Meta:
end_date = auto_field()
duration = auto_field()
state = TaskInstanceStateField()
_try_number = auto_field(data_key="try_number")
try_number = auto_field()
max_tries = auto_field()
task_display_name = fields.String(attribute="task_display_name", dump_only=True)
hostname = auto_field()
@@ -30,7 +30,7 @@ class DecreasingPriorityStrategy(PriorityWeightStrategy):
"""A priority weight strategy that decreases the priority weight with each attempt of the DAG task."""

def get_weight(self, ti: TaskInstance):
return max(3 - ti._try_number + 1, 1)
return max(3 - ti.try_number + 1, 1)


class DecreasingPriorityWeightStrategyPlugin(AirflowPlugin):
30 changes: 26 additions & 4 deletions airflow/jobs/backfill_job_runner.py
@@ -22,7 +22,7 @@

import attr
import pendulum
from sqlalchemy import select, tuple_, update
from sqlalchemy import case, or_, select, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import make_transient
from tabulate import tabulate
@@ -245,7 +245,16 @@ def _update_counters(self, ti_status: _DagRunTaskStatus, session: Session) -> No
session.execute(
update(TI)
.where(filter_for_tis)
.values(state=TaskInstanceState.SCHEDULED)
.values(
state=TaskInstanceState.SCHEDULED,
try_number=case(
(
or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
TI.try_number + 1,
),
else_=TI.try_number,
),
)
.execution_options(synchronize_session=False)
)
session.flush()
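
The conditional bump above is built with SQLAlchemy's case() and or_(); the same pattern in a self-contained sketch against a toy model (not Airflow code):

from sqlalchemy import Column, Integer, String, case, or_, update
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class TI(Base):
    # Toy stand-in for the TaskInstance model, just to render the expression.
    __tablename__ = "task_instance"
    id = Column(Integer, primary_key=True)
    state = Column(String)
    try_number = Column(Integer, default=0)


stmt = update(TI).values(
    state="scheduled",
    # Increment unless the task instance is resuming from up_for_reschedule.
    try_number=case(
        (or_(TI.state.is_(None), TI.state != "up_for_reschedule"), TI.try_number + 1),
        else_=TI.try_number,
    ),
)
print(stmt)  # renders an UPDATE ... SET try_number=CASE WHEN ... END statement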
@@ -425,6 +434,8 @@ def _task_instances_for_dag_run(
try:
for ti in dag_run.get_task_instances(session=session):
if ti in schedulable_tis:
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.set_state(TaskInstanceState.SCHEDULED)
if ti.state != TaskInstanceState.REMOVED:
tasks_to_run[ti.key] = ti
@@ -515,6 +526,7 @@ def _per_task_process(key, ti: TaskInstance, session):
if key in ti_status.running:
ti_status.running.pop(key)
# Reset the failed task in backfill to scheduled state
ti.try_number += 1
ti.set_state(TaskInstanceState.SCHEDULED, session=session)
if ti.dag_run not in ti_status.active_runs:
ti_status.active_runs.add(ti.dag_run)
@@ -552,6 +564,14 @@ def _per_task_process(key, ti: TaskInstance, session):
else:
self.log.debug("Sending %s to executor", ti)
# Skip scheduled state, we are executing immediately
if ti.state in (TaskInstanceState.UP_FOR_RETRY, None):
# i am not sure why this is necessary.
# seemingly a quirk of backfill runner.
# it should be handled elsewhere i think.
# seems the leaf tasks are set SCHEDULED but others not.
# but i am not going to look too closely since we need
# to nuke the current backfill approach anyway.
ti.try_number += 1
ti.state = TaskInstanceState.QUEUED
ti.queued_by_job_id = self.job.id
ti.queued_dttm = timezone.utcnow()
@@ -695,7 +715,9 @@ def _per_task_process(key, ti: TaskInstance, session):
self.log.debug(e)

perform_heartbeat(
job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True
job=self.job,
heartbeat_callback=self.heartbeat_callback,
only_if_necessary=True,
)
# execute the tasks in the queue
executor.heartbeat()
@@ -725,6 +747,7 @@ def to_keep(key: TaskInstanceKey) -> bool:
ti_status.to_run.update({ti.key: ti for ti in new_mapped_tis})

for new_ti in new_mapped_tis:
new_ti.try_number += 1
new_ti.set_state(TaskInstanceState.SCHEDULED, session=session)

# Set state to failed for running TIs that are set up for retry if disable-retry flag is set
@@ -930,7 +953,6 @@ def _execute(self, session: Session = NEW_SESSION) -> None:
"combination. Please adjust backfill dates or wait for this DagRun to finish.",
)
return
# picklin'
pickle_id = None

executor_class, _ = ExecutorLoader.import_default_executor_cls()
2 changes: 2 additions & 0 deletions airflow/models/dag.py
@@ -2948,6 +2948,8 @@ def add_logger_if_needed(ti: TaskInstance):
session.expire_all()
schedulable_tis, _ = dr.update_state(session=session)
for s in schedulable_tis:
if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
s.try_number += 1
s.state = TaskInstanceState.SCHEDULED
session.commit()
# triggerer may mark tasks scheduled so we read from DB
16 changes: 13 additions & 3 deletions airflow/models/dagrun.py
@@ -45,7 +45,7 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
from sqlalchemy.sql.expression import false, select, true
from sqlalchemy.sql.expression import case, false, select, true

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
@@ -1545,7 +1545,8 @@ def schedule_tis(
and not ti.task.on_success_callback
and not ti.task.outlets
):
ti._try_number += 1
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.defer_task(
defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
session=session,
@@ -1567,7 +1568,16 @@ def schedule_tis(
TI.run_id == self.run_id,
tuple_in_condition((TI.task_id, TI.map_index), schedulable_ti_ids_chunk),
)
.values(state=TaskInstanceState.SCHEDULED)
.values(
state=TaskInstanceState.SCHEDULED,
try_number=case(
(
or_(TI.state.is_(None), TI.state != TaskInstanceState.UP_FOR_RESCHEDULE),
TI.try_number + 1,
),
else_=TI.try_number,
),
)
.execution_options(synchronize_session=False)
).rowcount
