AIP-45 Remove dag parsing in airflow run local #21877

Merged 1 commit on May 12, 2022
2 changes: 0 additions & 2 deletions airflow/cli/cli_parser.py
@@ -512,7 +512,6 @@ def string_lower_type(val):
("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
)
ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error")
ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index")
@@ -1264,7 +1263,6 @@ class GroupCommand(NamedTuple):
ARG_PICKLE,
ARG_JOB_ID,
ARG_INTERACTIVE,
ARG_ERROR_FILE,
ARG_SHUT_DOWN_LOGGING,
ARG_MAP_INDEX,
),
7 changes: 5 additions & 2 deletions airflow/cli/commands/task_command.py
@@ -45,6 +45,7 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
@@ -258,7 +259,6 @@ def _run_raw_task(args, ti: TaskInstance) -> None:
mark_success=args.mark_success,
job_id=args.job_id,
pool=args.pool,
error_file=args.error_file,
)


@@ -357,7 +357,10 @@ def task_run(args, dag=None):
print(f'Loading pickle id: {args.pickle}')
dag = get_dag_by_pickle(args.pickle)
elif not dag:
dag = get_dag(args.subdir, args.dag_id)
if args.local:
dag = get_dag_by_deserialization(args.dag_id)
else:
dag = get_dag(args.subdir, args.dag_id)
else:
# Use DAG from parameter
pass
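With this change, `airflow tasks run --local` no longer re-parses the DAG file; it loads the DAG from the serialized copy in the metadata database. The `get_dag_by_deserialization` helper imported above lives in `airflow/utils/cli.py` and its body is not part of this excerpt; below is a minimal sketch of what it might look like, assuming it simply wraps the `SerializedDagModel.get_dag` classmethod added later in this diff.

```python
from airflow.exceptions import AirflowException
from airflow.models.serialized_dag import SerializedDagModel


def get_dag_by_deserialization(dag_id: str):
    """Load a DAG from the serialized_dag table instead of parsing DAG files."""
    dag = SerializedDagModel.get_dag(dag_id)
    if dag is None:
        raise AirflowException(f"Serialized DAG {dag_id!r} could not be found")
    return dag
```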
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py
@@ -607,7 +607,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
# TODO: Use simple_ti to improve performance here in the future
ti.refresh_from_db()
ti.handle_failure_with_callback(error=request.msg, test_mode=self.UNIT_TEST_MODE)
ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)

@provide_session
3 changes: 0 additions & 3 deletions airflow/executors/debug_executor.py
@@ -66,7 +66,6 @@ def sync(self) -> None:
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
ti._run_finished_callback()
continue

task_succeeded = self._run_task(ti)
@@ -78,12 +77,10 @@ def _run_task(self, ti: TaskInstance) -> bool:
params = self.tasks_params.pop(ti.key, {})
ti._run_raw_task(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
ti._run_finished_callback()
return True
except Exception as e:
ti.set_state(State.FAILED)
self.change_state(key, State.FAILED)
ti._run_finished_callback(error=e)
self.log.exception("Failed to execute task: %s.", str(e))
return False

2 changes: 1 addition & 1 deletion airflow/jobs/backfill_job.py
@@ -261,7 +261,7 @@ def _manage_executor_state(
f"{ti.state}. Was the task killed externally? Info: {info}"
)
self.log.error(msg)
ti.handle_failure_with_callback(error=msg)
ti.handle_failure(error=msg)
continue
if ti.state not in self.STATES_COUNT_AS_RUNNING:
# Don't use ti.task; if this task is mapped, that attribute
73 changes: 26 additions & 47 deletions airflow/jobs/local_task_job.py
@@ -73,6 +73,8 @@ def __init__(
# terminate multiple times
self.terminating = False

self._state_change_checks = 0

super().__init__(*args, **kwargs)

def _execute(self):
@@ -84,7 +86,6 @@ def signal_handler(signum, frame):
self.log.error("Received SIGTERM. Terminating subprocesses")
self.task_runner.terminate()
self.handle_task_exit(128 + signum)
return

signal.signal(signal.SIGTERM, signal_handler)

@@ -106,13 +107,15 @@ def signal_handler(signum, frame):

heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

# task callback invocation happens either here or in
# self.heartbeat() instead of taskinstance._run_raw_task to
# avoid race conditions
#
# When self.terminating is set to True by heartbeat_callback, this
# loop should not be restarted. Otherwise self.handle_task_exit
# will be invoked and we will end up with duplicated callbacks
# LocalTaskJob should not run callbacks, which are handled by TaskInstance._run_raw_task
# 1. LocalTaskJob does not parse the DAG, so it cannot run callbacks
# 2. The run_as_user of LocalTaskJob is likely not the same as that of TaskInstance._run_raw_task.
# When run_as_user is specified, the process owner of the LocalTaskJob must be sudoable.
# It is not secure to run callbacks with sudoable users.

# If _run_raw_task receives SIGKILL, scheduler will mark it as zombie and invoke callbacks
# If LocalTaskJob receives SIGTERM, LocalTaskJob passes SIGTERM to _run_raw_task
# If the state of task_instance is changed, LocalTaskJob sends SIGTERM to _run_raw_task
while not self.terminating:
# Monitor the task to see if it's done. Wait in a syscall
# (`os.wait`) for as long as possible so we notice the
@@ -150,26 +153,18 @@ def signal_handler(signum, frame):
self.on_kill()

def handle_task_exit(self, return_code: int) -> None:
"""Handle case where self.task_runner exits by itself or is externally killed"""
"""
Handle case where self.task_runner exits by itself or is externally killed

Don't run any callbacks
"""
# Without setting this, heartbeat may get us
self.terminating = True
self.log.info("Task exited with return code %s", return_code)
self.task_instance.refresh_from_db()

if self.task_instance.state == State.RUNNING:
# This is for a case where the task received a SIGKILL
# while running or the task runner received a sigterm
self.task_instance.handle_failure(error=None)
# We need to check for error file
# in case it failed due to runtime exception/error
error = None
if self.task_instance.state != State.SUCCESS:
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error)
if not self.task_instance.test_mode:
if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
self._run_mini_scheduler_on_child_tasks()
self._update_dagrun_state_for_paused_dag()

def on_kill(self):
self.task_runner.terminate()
@@ -217,19 +212,16 @@ def heartbeat_callback(self, session=None):
dagrun_timeout = ti.task.dag.dagrun_timeout
if dagrun_timeout and execution_time > dagrun_timeout:
self.log.warning("DagRun timed out after %s.", str(execution_time))
self.log.warning(
"State of this instance has been externally set to %s. Terminating instance.", ti.state
)
self.task_runner.terminate()
if ti.state == State.SUCCESS:
error = None
else:
# if ti.state is not set by taskinstance.handle_failure, then
# error file will not be populated and it must be updated by
# external source such as web UI
error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
ti._run_finished_callback(error=error)
self.terminating = True

# Potential race condition: _run_raw_task commits `success` or another state,
# but the task_runner does not exit right away due to slow process shutdown or other reasons.
# Throttle here; if the above case is true, handle_task_exit will handle it.
if self._state_change_checks >= 1: # defer to next round of heartbeat
self.log.warning(
"State of this instance has been externally set to %s. Terminating instance.", ti.state
)
self.terminating = True
self._state_change_checks += 1

@provide_session
@Sentry.enrich_errors
Expand Down Expand Up @@ -282,19 +274,6 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
)
session.rollback()

@provide_session
def _update_dagrun_state_for_paused_dag(self, session=None):
"""
Checks for paused dags with DagRuns in the running state and
update the DagRun state if possible
"""
dag = self.task_instance.task.dag
if dag.get_is_paused():
dag_run = self.task_instance.get_dagrun(session=session)
if dag_run:
dag_run.dag = dag
dag_run.update_state(session=session, execute_callbacks=True)

@staticmethod
def _enable_task_listeners():
"""
22 changes: 22 additions & 0 deletions airflow/jobs/scheduler_job.py
@@ -149,6 +149,7 @@ def __init__(
self.processor_agent: Optional[DagFileProcessorAgent] = None

self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
self._paused_dag_without_running_dagruns: Set = set()

if conf.getboolean('smart_sensor', 'use_smart_sensor'):
compatible_sensors = set(
@@ -764,6 +765,26 @@ def _execute(self) -> None:
self.log.exception("Exception when executing DagFileProcessorAgent.end")
self.log.info("Exited execute loop")

def _update_dag_run_state_for_paused_dags(self):

Member: Why did this have to move out of LocalTaskJob?

Contributor Author: Since no callbacks are executed in the LocalTaskJob, we need to rely on the scheduler to check and update the paused dag runs.

Member: I'm not too happy about this change -- the point of having this in LocalTaskJob is so that the state of a dag run updates quickly after the task finishes.

By moving it to the scheduler we've lost that.

Instead of moving this to the scheduler, could we move it "down" into the raw task process instead? That, or I wonder if there is some way with AIP-44 to add a callback from the LTJ if needed.

Final option: if there are no callbacks we can still safely do this here, so maybe we only need to do this in the scheduler in case of callbacks?

Contributor Author: Good point on having the LocalTaskJob mark the dag run quickly. My concern is that it leaks the responsibility of the scheduler; also, there is a chance that the LTJ fails before marking the dag run state, which leaves unhandled dag run states (this could be rare). And since the dag is paused, it is not a big deal if the dag run state is updated a little late.

As we discussed in the email thread "[DISCUSSION] let scheduler heal tasks stuck in queued state", we will need to define the responsibility of each component in Airflow (scheduler, executor, LTJ (airflow run --local), airflow run --raw) in terms of the state machine.

We can have a thorough discussion there. Let me know your thoughts.

Member: I quite agree that there is no need to update paused tasks' state immediately -- this makes this part of the change far more appealing. Since this change is merged -- @ashb -- if you still have some concerns, we can continue discussing here and maybe make a follow-up change in case we think of a scenario where it might be problematic.

try:
paused_dag_ids = DagModel.get_all_paused_dag_ids()
for dag_id in paused_dag_ids:
if dag_id in self._paused_dag_without_running_dagruns:
continue

Contributor (on lines +772 to +773): dag_ids are never removed from this set. If I unpause the DAG after this has run and then later re-pause it, then this method skips it.

Contributor Author: Good catch, let me open a PR.


dag = SerializedDagModel.get_dag(dag_id)
if dag is None:
continue
dag_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING)
for dag_run in dag_runs:
dag_run.dag = dag
_, callback_to_run = dag_run.update_state(execute_callbacks=False)
if callback_to_run:
self._send_dag_callbacks_to_processor(dag, callback_to_run)
self._paused_dag_without_running_dagruns.add(dag_id)
except Exception as e: # should not fail the scheduler
self.log.exception('Failed to update dag run state for paused dags due to %s', str(e))

def _run_scheduler_loop(self) -> None:
"""
The actual scheduler loop. The main steps in the loop are:
Expand Down Expand Up @@ -809,6 +830,7 @@ def _run_scheduler_loop(self) -> None:
conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0),
self._find_zombies,
)
timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)

for loop_count in itertools.count(start=1):
with Stats.timer() as timer:
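As the reviewer notes above, `_paused_dag_without_running_dagruns` only ever grows, so a DAG that is unpaused and later re-paused would be skipped by this method. A hypothetical follow-up sketch (not part of this PR, reusing the imports and attributes of the surrounding method) could prune the set against the currently paused ids before the loop:

```python
def _update_dag_run_state_for_paused_dags(self):
    try:
        paused_dag_ids = DagModel.get_all_paused_dag_ids()
        # Forget DAGs that were unpaused since the last pass, so that re-pausing
        # them makes their running DagRuns eligible for a state check again.
        self._paused_dag_without_running_dagruns &= paused_dag_ids
        for dag_id in paused_dag_ids:
            if dag_id in self._paused_dag_without_running_dagruns:
                continue
            ...  # unchanged: deserialize the DAG and update its RUNNING DagRuns
    except Exception as e:  # should not fail the scheduler
        self.log.exception('Failed to update dag run state for paused dags due to %s', str(e))
```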
9 changes: 9 additions & 0 deletions airflow/models/dag.py
@@ -2762,6 +2762,15 @@ def get_dagmodel(dag_id, session=NEW_SESSION):
def get_current(cls, dag_id, session=NEW_SESSION):
return session.query(cls).filter(cls.dag_id == dag_id).first()

@staticmethod
@provide_session
def get_all_paused_dag_ids(session: Session = NEW_SESSION) -> Set[str]:
"""Get a set of paused DAG ids"""
paused_dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_paused == expression.true()).all()

paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids}
return paused_dag_ids

@provide_session
def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
return get_last_dagrun(
8 changes: 8 additions & 0 deletions airflow/models/serialized_dag.py
@@ -250,6 +250,14 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool:
"""
return session.query(literal(True)).filter(cls.dag_id == dag_id).first() is not None

@classmethod
@provide_session
def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDAG']:
row = cls.get(dag_id, session=session)
if row:
return row.dag
return None

@classmethod
@provide_session
def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']:
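The new classmethod gives callers a one-line way to materialize a DAG straight from the `serialized_dag` table, with no DAG-file parsing, returning `None` if the DAG has never been serialized. A small usage sketch (the dag_id here is hypothetical):

```python
from airflow.models.serialized_dag import SerializedDagModel

# Fetch a DAG from the metadata DB without parsing any DAG files.
dag = SerializedDagModel.get_dag("example_dag_id")  # hypothetical dag_id
if dag is not None:
    print(dag.dag_id, dag.task_ids)
```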