From 5f6c2538728188f8efdf0717a4f6d0f11f8f2e72 Mon Sep 17 00:00:00 2001
From: Jakub Dardzinski
Date: Mon, 8 Jul 2024 00:13:20 +0200
Subject: [PATCH] Remove default `create_dataset` method.

Add section in experimental lineage docs.

Signed-off-by: Jakub Dardzinski
---
 airflow/datasets/__init__.py                  |  5 --
 airflow/lineage/hook.py                       | 70 +++++++++++++------
 airflow/plugins_manager.py                    | 27 +++++++
 airflow/provider.yaml.schema.json             |  7 --
 airflow/providers_manager.py                  | 42 ++++-------
 .../administration-and-deployment/lineage.rst | 42 +++++++++++
 tests/lineage/test_hook.py                    | 50 +++++++------
 tests/test_utils/mock_plugins.py              |  1 +
 8 files changed, 162 insertions(+), 82 deletions(-)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 11eca7f62fb1f1..7e26df496b9dbe 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -44,11 +44,6 @@ def normalize_noop(parts: SplitResult) -> SplitResult:
     return parts
 
 
-def create_dataset(uri: str) -> Dataset:
-    """Create a dataset object from a dataset URI."""
-    return Dataset(uri=uri)
-
-
 def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
     if scheme == "file":
         return normalize_noop
diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 9510f1b2d4a111..cc60b843bea2ea 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -21,7 +21,7 @@
 
 import attr
 
-from airflow.datasets import Dataset, create_dataset
+from airflow.datasets import Dataset
 from airflow.hooks.base import BaseHook
 from airflow.io.store import ObjectStore
 from airflow.providers_manager import ProvidersManager
@@ -53,36 +53,61 @@ def __init__(self, **kwargs):
         self.inputs: list[tuple[Dataset, LineageContext]] = []
         self.outputs: list[tuple[Dataset, LineageContext]] = []
 
-    @staticmethod
-    def create_dataset(dataset_kwargs: dict) -> Dataset:
+    def create_dataset(
+        self, scheme: str | None, uri: str | None, dataset_kwargs: dict | None, dataset_extra: dict | None
+    ) -> Dataset | None:
-        """Create a Dataset instance from the given dataset kwargs."""
+        """Create a Dataset instance from the given parameters."""
-        if "uri" in dataset_kwargs:
+        if uri:
             # Fallback to default factory using the provided URI
-            return create_dataset(dataset_kwargs["uri"])
+            return Dataset(uri=uri, extra=dataset_extra)
 
-        scheme: str = dataset_kwargs.pop("scheme", None)
         if not scheme:
-            raise ValueError(
+            self.log.debug(
                 "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
             )
+            return None
 
         dataset_factory = ProvidersManager().dataset_factories.get(scheme)
         if not dataset_factory:
-            raise ValueError(
-                f"Unsupported scheme: '{scheme}'. Please provide a valid URI to create a Dataset."
-            )
+            self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create a Dataset.", scheme)
+            return None
 
-        return dataset_factory(**dataset_kwargs)
+        try:
+            return dataset_factory(**(dataset_kwargs or {}), extra=dataset_extra)
+        except Exception as e:
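+            # A provider-supplied factory may raise arbitrary errors; log and skip
+            # rather than propagating the failure to the calling hook.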
+            self.log.warning("Failed to create dataset. Skipping. Error: %s", e)
+            return None
 
-    def add_input_dataset(self, dataset_kwargs: dict, hook: LineageContext):
+    def add_input_dataset(
+        self,
+        hook: LineageContext,
+        scheme: str | None = None,
+        uri: str | None = None,
+        dataset_kwargs: dict | None = None,
+        dataset_extra: dict | None = None,
+    ):
         """Add the input dataset and its corresponding hook execution context to the collector."""
-        dataset = self.create_dataset(dataset_kwargs)
-        self.inputs.append((dataset, hook))
+        dataset = self.create_dataset(
+            scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
+        )
+        if dataset:
+            self.inputs.append((dataset, hook))
 
-    def add_output_dataset(self, dataset_kwargs: dict, hook: LineageContext):
+    def add_output_dataset(
+        self,
+        hook: LineageContext,
+        scheme: str | None = None,
+        uri: str | None = None,
+        dataset_kwargs: dict | None = None,
+        dataset_extra: dict | None = None,
+    ):
         """Add the output dataset and its corresponding hook execution context to the collector."""
-        dataset = self.create_dataset(dataset_kwargs)
-        self.outputs.append((dataset, hook))
+        dataset = self.create_dataset(
+            scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
+        )
+        if dataset:
+            self.outputs.append((dataset, hook))
 
     @property
     def collected_datasets(self) -> HookLineage:
@@ -112,7 +137,9 @@ def add_output_dataset(self, *_):
     def collected_datasets(
         self,
     ) -> HookLineage:
-        self.log.warning("You should not call this as there's no reader.")
+        self.log.warning(
+            "Data lineage tracking is disabled. Register a hook lineage reader to start tracking hook lineage."
+        )
         return HookLineage([], [])
 
 
@@ -132,8 +159,11 @@ def get_hook_lineage_collector() -> HookLineageCollector:
     """Get singleton lineage collector."""
     global _hook_lineage_collector
     if not _hook_lineage_collector:
-        # is there a better why how to use noop?
-        if ProvidersManager().hook_lineage_readers:
+        from airflow import plugins_manager
+
+        plugins_manager.initialize_hook_lineage_readers_plugins()
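+        # Collect lineage only when a plugin registered at least one HookLineageReader.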
+        if plugins_manager.hook_lineage_reader_classes:
             _hook_lineage_collector = HookLineageCollector()
         else:
             _hook_lineage_collector = NoOpCollector()
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 63b3dbd80d47a5..76d72a45850093 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -41,6 +41,8 @@
 from airflow.utils.module_loading import import_string, qualname
 
 if TYPE_CHECKING:
+    from airflow.lineage.hook import HookLineageReader
+
     try:
         import importlib_metadata as metadata
     except ImportError:
@@ -75,6 +77,7 @@
 registered_operator_link_classes: dict[str, type] | None = None
 registered_ti_dep_classes: dict[str, type] | None = None
 timetable_classes: dict[str, type[Timetable]] | None = None
+hook_lineage_reader_classes: list[type[HookLineageReader]] | None = None
 priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
 """
 Mapping of class names to class of OperatorLinks registered by plugins.
@@ -176,8 +179,12 @@ class AirflowPlugin:
     # A list of timetable classes that can be used for DAG scheduling.
     timetables: list[type[Timetable]] = []
 
+    # A list of listeners that can be used for tracking task and DAG states.
     listeners: list[ModuleType | object] = []
 
+    # A list of hook lineage reader classes that can be used for reading lineage information from a hook.
+    hook_lineage_readers: list[type[HookLineageReader]] = []
+
     # A list of priority weight strategy classes that can be used for calculating tasks weight priority.
     priority_weight_strategies: list[type[PriorityWeightStrategy]] = []
 
@@ -483,6 +490,25 @@ def initialize_timetables_plugins():
     }
 
 
+def initialize_hook_lineage_readers_plugins():
+    """Collect hook lineage reader classes registered by plugins."""
+    global hook_lineage_reader_classes
+
+    if hook_lineage_reader_classes is not None:
+        return
+
+    ensure_plugins_loaded()
+
+    if plugins is None:
+        raise AirflowPluginException("Can't load plugins.")
+
+    log.debug("Initialize hook lineage readers plugins")
+
+    hook_lineage_reader_classes = []
+    for plugin in plugins:
+        hook_lineage_reader_classes.extend(plugin.hook_lineage_readers)
+
+
 def integrate_executor_plugins() -> None:
     """Integrate executor plugins to the context."""
     global plugins
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index 0498ad916ab837..adbca7846d19e8 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -220,13 +220,6 @@
         }
       }
     },
-    "hook-lineage-readers": {
-      "type": "array",
-      "description": "Hook lineage readers",
-      "items": {
-        "type": "string"
-      }
-    },
     "transfers": {
       "type": "array",
       "items": {
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index eeaaa3e083c7c4..9e9dd4d573ddd7 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -428,7 +428,6 @@ def __init__(self):
         self._fs_set: set[str] = set()
         self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
         self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
-        self._hook_lineage_readers: set[str] = set()
         self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
         # keeps mapping between connection_types and hook class, package they come from
         self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -526,17 +525,10 @@ def initialize_providers_filesystems(self):
         self._discover_filesystems()
 
     @provider_info_cache("dataset_uris")
-    def initialize_providers_dataset_uri_handlers(self):
-        """Lazy initialization of provider dataset URI handlers."""
+    def initialize_providers_dataset_uri_handlers_and_factories(self):
+        """Lazy initialization of provider dataset URI handlers and dataset factories."""
         self.initialize_providers_list()
         self._discover_dataset_uri_handlers_and_factories()
 
-    @provider_info_cache("hook_lineage_readers")
-    def initialize_providers_hook_lineage_readers(self):
-        """Lazy initialization of providers hook lineage readers."""
-        self.initialize_providers_list()
-        self._discover_hook_lineage_readers()
-
-    @provider_info_cache("hook_lineage_writers")
     @provider_info_cache("taskflow_decorators")
     def initialize_providers_taskflow_decorator(self):
@@ -574,7 +566,7 @@ def initialize_providers_notifications(self):
         self.initialize_providers_list()
         self._discover_notifications()
 
-    @provider_info_cache(cache_name="auth_managers")
+    @provider_info_cache("auth_managers")
     def initialize_providers_auth_managers(self):
         """Lazy initialization of providers notifications information."""
         self.initialize_providers_list()
@@ -889,34 +881,28 @@ def _discover_filesystems(self) -> None:
         self._fs_set = set(sorted(self._fs_set))
 
     def _discover_dataset_uri_handlers_and_factories(self) -> None:
-        from airflow.datasets import create_dataset, normalize_noop
+        from airflow.datasets import normalize_noop
 
         for provider_package, provider in self._provider_dict.items():
             for handler_info in provider.data.get("dataset-uris", []):
                 try:
                     schemes = handler_info["schemes"]
                     handler_path = handler_info["handler"]
-                    factory_path = handler_info["factory"]
                 except KeyError:
                     continue
                 if handler_path is None:
                     handler = normalize_noop
-                if factory_path is None:
-                    factory = create_dataset
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)) or not (
-                    factory := _correctness_check(provider_package, factory_path, provider)
-                ):
+                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
                     continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
+                factory_path = handler_info.get("factory")
+                if not (
+                    factory_path is not None
+                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                ):
+                    continue
                 self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
-    def _discover_hook_lineage_readers(self) -> None:
-        for provider_package, provider in self._provider_dict.items():
-            for hook_lineage_reader in provider.data.get("hook-lineage-readers", []):
-                if _correctness_check(provider_package, hook_lineage_reader, provider):
-                    self._hook_lineage_readers.add(hook_lineage_reader)
-        self._fs_set = set(sorted(self._fs_set))
-
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
             for taskflow_decorator in info.data.get("task-decorators", []):
@@ -1314,19 +1300,14 @@ def filesystem_module_names(self) -> list[str]:
 
     @property
     def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
-        self.initialize_providers_dataset_uri_handlers()
+        self.initialize_providers_dataset_uri_handlers_and_factories()
         return self._dataset_factories
 
     @property
     def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
-        self.initialize_providers_dataset_uri_handlers()
+        self.initialize_providers_dataset_uri_handlers_and_factories()
         return self._dataset_uri_handlers
 
-    @property
-    def hook_lineage_readers(self) -> list[str]:
-        self.initialize_providers_hook_lineage_readers()
-        return sorted(self._hook_lineage_readers)
-
     @property
     def provider_configs(self) -> list[tuple[str, dict[str, Any]]]:
         self.initialize_providers_configuration()
diff --git a/docs/apache-airflow/administration-and-deployment/lineage.rst b/docs/apache-airflow/administration-and-deployment/lineage.rst
index 7b967d3f10e61a..4ea210ff955ea1 100644
--- a/docs/apache-airflow/administration-and-deployment/lineage.rst
+++ b/docs/apache-airflow/administration-and-deployment/lineage.rst
@@ -89,6 +89,66 @@ has outlets defined (e.g. by using ``add_outlets(..)`` or has out of the box sup
 
 .. _precedence: https://docs.python.org/3/reference/expressions.html
 
+Hook Lineage
+------------
+
+Airflow provides a powerful feature for tracking data lineage not only between tasks but also from hooks used within those tasks.
+This functionality helps you understand how data flows throughout your Airflow pipelines.
+
+A global instance of ``HookLineageCollector`` serves as the central hub for collecting lineage information.
+Hooks can send details about datasets they interact with to this collector.
+The collector then uses this data to construct AIP-60 compliant Datasets, a standard format for describing datasets.
+
+.. code-block:: python
+
+    from airflow.hooks.base import BaseHook
+    from airflow.lineage.hook import get_hook_lineage_collector
+
+
+    class CustomHook(BaseHook):
+        def run(self):
+            # run actual code
+            collector = get_hook_lineage_collector()
+            collector.add_input_dataset(self, scheme="file", dataset_kwargs={"path": "/tmp/in"})
+            collector.add_output_dataset(self, scheme="file", dataset_kwargs={"path": "/tmp/out"})
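+
+If the dataset URI is already known, it can be passed directly instead of a scheme; the collector then
+builds a plain AIP-60 ``Dataset`` from it. A minimal sketch (the ``file:///tmp/in`` URI is illustrative):
+
+.. code-block:: python
+
+    collector.add_input_dataset(self, uri="file:///tmp/in")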
+
+Lineage data collected by the ``HookLineageCollector`` can be accessed using an instance of
+``HookLineageReader``, which is registered in an Airflow plugin.
+
+.. code-block:: python
+
+    from airflow.lineage.hook import HookLineageReader
+    from airflow.plugins_manager import AirflowPlugin
+
+
+    class CustomHookLineageReader(HookLineageReader):
+        def get_inputs(self):
+            return self.lineage_collector.collected_datasets.inputs
+
+
+    class HookLineageCollectionPlugin(AirflowPlugin):
+        name = "HookLineageCollectionPlugin"
+        hook_lineage_readers = [CustomHookLineageReader]
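+
+The collected lineage itself is exposed through the collector's ``collected_datasets`` property, which
+returns a ``HookLineage`` holding ``(dataset, hook)`` pairs for inputs and outputs. A minimal sketch:
+
+.. code-block:: python
+
+    lineage = get_hook_lineage_collector().collected_datasets
+    for dataset, hook in lineage.inputs:
+        print(dataset.uri)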
+
+If no ``HookLineageReader`` is registered within Airflow, a default ``NoOpCollector`` is used instead.
+This collector does not create AIP-60 compliant datasets or collect lineage information.
+
 Lineage Backend
 ---------------
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
index 876b6e04109cda..88d8e297cea611 100644
--- a/tests/lineage/test_hook.py
+++ b/tests/lineage/test_hook.py
@@ -21,10 +21,18 @@
 
 import pytest
 
+from airflow import plugins_manager
 from airflow.datasets import Dataset
 from airflow.hooks.base import BaseHook
 from airflow.lineage import hook
-from airflow.lineage.hook import HookLineage, HookLineageCollector, NoOpCollector, get_hook_lineage_collector
+from airflow.lineage.hook import (
+    HookLineage,
+    HookLineageCollector,
+    HookLineageReader,
+    NoOpCollector,
+    get_hook_lineage_collector,
+)
+from tests.test_utils.mock_plugins import mock_plugin_manager
 
 
 class TestHookLineageCollector:
@@ -34,27 +42,26 @@ def test_are_datasets_collected(self):
         assert lineage_collector.collected_datasets == HookLineage()
         input_hook = BaseHook()
         output_hook = BaseHook()
-        lineage_collector.add_input_dataset({"uri": "s3://in_bucket/file"}, input_hook)
+        lineage_collector.add_input_dataset(input_hook, uri="s3://in_bucket/file")
         lineage_collector.add_output_dataset(
-            {"uri": "postgres://example.com:5432/database/default/table"}, output_hook
+            output_hook, uri="postgres://example.com:5432/database/default/table"
         )
         assert lineage_collector.collected_datasets == HookLineage(
             [(Dataset("s3://in_bucket/file"), input_hook)],
             [(Dataset("postgres://example.com:5432/database/default/table"), output_hook)],
         )
 
-    @patch("airflow.lineage.hook.create_dataset")
-    def test_add_input_dataset(self, mock_create_dataset):
+    @patch("airflow.lineage.hook.Dataset")
+    def test_add_input_dataset(self, mock_dataset):
         collector = HookLineageCollector()
-        mock_dataset = MagicMock(spec=Dataset)
-        mock_create_dataset.return_value = mock_dataset
+        dataset = MagicMock(spec=Dataset)
+        mock_dataset.return_value = dataset
 
-        dataset_kwargs = {"uri": "test_uri"}
         hook = MagicMock()
-        collector.add_input_dataset(dataset_kwargs, hook)
+        collector.add_input_dataset(hook, uri="test_uri")
 
-        assert collector.inputs == [(mock_dataset, hook)]
-        mock_create_dataset.assert_called_once_with("test_uri")
+        assert collector.inputs == [(dataset, hook)]
+        mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
 
     @patch("airflow.lineage.hook.ProvidersManager")
     def test_create_dataset(self, mock_providers_manager):
@@ -89,6 +96,11 @@ def test_has_collected(self):
         assert collector.has_collected
 
 
+class FakePlugin(plugins_manager.AirflowPlugin):
+    name = "FakePluginHavingHookLineageCollector"
+    hook_lineage_readers = [HookLineageReader]
+
+
 @pytest.mark.parametrize(
     "has_readers, expected_class",
     [
@@ -96,13 +108,10 @@ def test_has_collected(self):
         (False, NoOpCollector),
     ],
 )
-@patch("airflow.lineage.hook.ProvidersManager")
-def test_get_hook_lineage_collector(mock_providers_manager, has_readers, expected_class):
+def test_get_hook_lineage_collector(has_readers, expected_class):
     # reset global variable
     hook._hook_lineage_collector = None
-    if has_readers:
-        mock_providers_manager.return_value.hook_lineage_readers = [MagicMock()]
-    else:
-        mock_providers_manager.return_value.hook_lineage_readers = []
-    assert isinstance(get_hook_lineage_collector(), expected_class)
-    assert get_hook_lineage_collector() is get_hook_lineage_collector()
+    plugins = [FakePlugin()] if has_readers else []
+    with mock_plugin_manager(plugins=plugins):
+        assert isinstance(get_hook_lineage_collector(), expected_class)
+        assert get_hook_lineage_collector() is get_hook_lineage_collector()
diff --git a/tests/test_utils/mock_plugins.py b/tests/test_utils/mock_plugins.py
index 3242f36159877e..3ab1acb730b9c0 100644
--- a/tests/test_utils/mock_plugins.py
+++ b/tests/test_utils/mock_plugins.py
@@ -33,6 +33,7 @@
     "operator_extra_links",
     "registered_operator_link_classes",
     "timetable_classes",
+    "hook_lineage_reader_classes",
 ]