Skip to content

Commit

Permalink
Implement context accessor for DatasetEvent extra (#38481)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Mar 29, 2024
1 parent a2f5307 commit fce3a58
Show file tree
Hide file tree
Showing 14 changed files with 228 additions and 68 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ repos:
entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
pass_filenames: false
files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group\.py$
- id: check-template-context-variable-in-sync
name: Check all template context variable references are in sync
language: python
entry: ./scripts/ci/pre_commit/pre_commit_template_context_key_sync.py
files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$
- id: check-base-operator-usage
language: pygrep
name: Check BaseOperator core imports
Expand Down
11 changes: 9 additions & 2 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N
return ProvidersManager().dataset_uri_handlers.get(scheme)


def _sanitize_uri(uri: str) -> str:
def sanitize_uri(uri: str) -> str:
"""Sanitize a dataset URI.
This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.
:meta private:
"""
if not uri:
raise ValueError("Dataset URI cannot be empty")
if uri.isspace():
Expand Down Expand Up @@ -110,7 +117,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(
converter=_sanitize_uri,
converter=sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None
Expand Down
14 changes: 11 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge
from airflow.utils.context import (
ConnectionAccessor,
Context,
DatasetEventAccessors,
VariableAccessor,
context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -766,6 +772,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti
"dag_run": dag_run,
"data_interval_end": timezone.coerce_datetime(data_interval.end),
"data_interval_start": timezone.coerce_datetime(data_interval.start),
"dataset_events": DatasetEventAccessors(),
"ds": ds,
"ds_nodash": ds_nodash,
"execution_date": logical_date,
Expand Down Expand Up @@ -2569,7 +2576,7 @@ def _run_raw_task(
session.add(Log(self.state, self))
session.merge(self).task = self.task
if self.state == TaskInstanceState.SUCCESS:
self._register_dataset_changes(session=session)
self._register_dataset_changes(events=context["dataset_events"], session=session)

session.commit()
if self.state == TaskInstanceState.SUCCESS:
Expand All @@ -2579,7 +2586,7 @@ def _run_raw_task(

return None

def _register_dataset_changes(self, *, session: Session) -> None:
def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: Session) -> None:
if TYPE_CHECKING:
assert self.task

Expand All @@ -2590,6 +2597,7 @@ def _register_dataset_changes(self, *, session: Session) -> None:
dataset_manager.register_dataset_change(
task_instance=self,
dataset=obj,
extra=events[obj].extra,
session=session,
)

Expand Down
34 changes: 34 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
ValuesView,
)

import attrs
import lazy_object_proxy

from airflow.datasets import Dataset, sanitize_uri
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils.types import NOTSET

Expand All @@ -54,6 +56,7 @@
"dag_run",
"data_interval_end",
"data_interval_start",
"dataset_events",
"ds",
"ds_nodash",
"execution_date",
Expand Down Expand Up @@ -146,6 +149,37 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


@attrs.define()
class DatasetEventAccessor:
"""Wrapper to access a DatasetEvent instance in template."""

extra: dict[str, Any]


class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
"""Lazy mapping of dataset event accessors."""

def __init__(self) -> None:
self._dict: dict[str, DatasetEventAccessor] = {}

def __iter__(self) -> Iterator[str]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
if isinstance(key, str):
uri = sanitize_uri(key)
elif isinstance(key, Dataset):
uri = key.uri
else:
return NotImplemented
if uri not in self._dict:
self._dict[uri] = DatasetEventAccessor({})
return self._dict[uri]


class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
"""Warn for usage of deprecated context variables in a task."""

Expand Down
12 changes: 11 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
# declare "these are defined, but don't error if others are accessed" someday.
from __future__ import annotations

from typing import Any, Collection, Container, Iterable, Mapping, overload
from typing import Any, Collection, Container, Iterable, Iterator, Mapping, overload

from pendulum import DateTime

from airflow.configuration import AirflowConfigParser
from airflow.datasets import Dataset
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
Expand All @@ -55,6 +56,14 @@ class VariableAccessor:
class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

class DatasetEventAccessor:
extra: dict[str, Any]

class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
Expand All @@ -65,6 +74,7 @@ class Context(TypedDict, total=False):
dag_run: DagRun | DagRunPydantic
data_interval_end: DateTime
data_interval_start: DateTime
dataset_events: DatasetEventAccessors
ds: str
ds_nodash: str
exception: BaseException | str | None
Expand Down
2 changes: 2 additions & 0 deletions contributing-docs/08_static_code_checks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-system-tests-tocs | Check that system tests is properly added | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-template-context-variable-in-sync | Check all template context variable references are in sync | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-tests-in-the-right-folders | Check if tests are in the right folders | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-tests-unittest-testcase | Check that unit tests do not inherit from unittest.TestCase | |
Expand Down
Loading

0 comments on commit fce3a58

Please sign in to comment.