Skip to content

Commit

Permalink
Implement get_dagrun on TaskInstancePydantic (#38295)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent <[email protected]>
  • Loading branch information
dstandish and vincbeck authored Mar 20, 2024
1 parent 0eb1405 commit b3f54e0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
1 change: 1 addition & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _initialize_map() -> dict[str, Callable]:
SerializedDagModel.get_serialized_dag,
TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
TaskInstance._get_dagrun,
TaskInstance.fetch_handle_failure_context,
TaskInstance.save_to_db,
TaskInstance._schedule_downstream_tasks,
Expand Down
13 changes: 9 additions & 4 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,6 +2115,14 @@ def ready_for_retry(self) -> bool:
"""Check on whether the task instance is in the right state and timeframe to be retried."""
return self.state == TaskInstanceState.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()

@staticmethod
@internal_api_call
def _get_dagrun(dag_id, run_id, session) -> DagRun:
from airflow.models.dagrun import DagRun # Avoid circular import

dr = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
return dr

@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
"""
Expand All @@ -2131,13 +2139,10 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
self.dag_run.dag = self.task.dag
return self.dag_run

from airflow.models.dagrun import DagRun # Avoid circular import

dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
dr = self._get_dagrun(self.dag_id, self.run_id, session)
if getattr(self, "task", None) is not None:
if TYPE_CHECKING:
assert self.task

dr.dag = self.task.dag
# Record it in the instance for next time. This means that `self.execution_date` will work correctly
set_committed_value(self, "dag_run", dr)
Expand Down
6 changes: 2 additions & 4 deletions airflow/serialization/pydantic/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,9 @@ def get_dagrun(self, session: Session | None = None) -> DagRunPydantic:
:param session: SQLAlchemy ORM Session
TODO: make it works for AIP-44
:return: Pydantic serialized version of DaGrun
:return: Pydantic serialized version of DagRun
"""
raise NotImplementedError()
return TaskInstance._get_dagrun(dag_id=self.dag_id, run_id=self.run_id, session=session)

def _execute_task(self, context, task_orig):
"""
Expand Down

0 comments on commit b3f54e0

Please sign in to comment.