From 6cc446a3f105d3e798c4dbad6f2fa2b1c71662e1 Mon Sep 17 00:00:00 2001
From: Jakub Dardzinski
Date: Tue, 16 Jul 2024 01:04:39 +0200
Subject: [PATCH] AIP-62: Add HookLineageCollector (#40335)

* Add HookLineageCollector, which registers and holds lineage sent from
hooks during task execution.

Add HookLineageReader, which defines whether HookLineageCollector should
be enabled to process lineage sent from hooks.

Add Dataset factories to make sure Datasets registered with
HookLineageCollector are AIP-60 compliant.

Signed-off-by: Jakub Dardzinski

* Remove default `create_dataset` method.
Add section in experimental lineage docs.

Signed-off-by: Jakub Dardzinski

---------

Signed-off-by: Jakub Dardzinski
---
 airflow/lineage/hook.py                       | 182 +++++++++++++++++++
 airflow/plugins_manager.py                    |  26 +++
 airflow/provider.yaml.schema.json             |   4 +
 airflow/providers_manager.py                  |  23 ++-
 .../administration-and-deployment/lineage.rst |  76 ++++++++
 tests/lineage/test_hook.py                    | 120 ++++++++++++
 tests/test_utils/mock_plugins.py              |   1 +
 7 files changed, 428 insertions(+), 4 deletions(-)
 create mode 100644 airflow/lineage/hook.py
 create mode 100644 tests/lineage/test_hook.py

diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
new file mode 100644
index 00000000000000..70893516bd9bdf
--- /dev/null
+++ b/airflow/lineage/hook.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Union
+
+import attr
+
+from airflow.datasets import Dataset
+from airflow.hooks.base import BaseHook
+from airflow.io.store import ObjectStore
+from airflow.providers_manager import ProvidersManager
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+# The context that sent the lineage: the hook or object store a dataset was used with.
+LineageContext = Union[BaseHook, ObjectStore]
+
+_hook_lineage_collector: HookLineageCollector | None = None
+
+
+@attr.define
+class HookLineage:
+    """Holds lineage collected by HookLineageCollector."""
+
+    inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
+    outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
+
+
+class HookLineageCollector(LoggingMixin):
+    """
+    HookLineageCollector is a base class for collecting hook lineage information.
+
+    It is used to collect the input and output datasets of a hook execution.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.inputs: list[tuple[Dataset, LineageContext]] = []
+        self.outputs: list[tuple[Dataset, LineageContext]] = []
+
+    def create_dataset(
+        self, scheme: str | None, uri: str | None, dataset_kwargs: dict | None, dataset_extra: dict | None
+    ) -> Dataset | None:
+        """
+        Create a Dataset instance using the provided parameters.
+
+        This method attempts to create a Dataset instance from the given parameters.
+        If a URI is provided, the Dataset is created directly from it and ``dataset_extra``
+        is attached as-is; no provider factory is consulted.
+
+        If a scheme is provided but no URI, it attempts to find a dataset factory that matches
+        the given scheme. If no such factory is found, it logs a debug message and returns None.
+
+        If dataset_kwargs is provided, it is used to pass additional parameters to the Dataset
+        factory. The dataset_extra parameter is also passed to the factory as an ``extra`` parameter.
+        """
+        if uri:
+            # The URI is known up front; create the Dataset directly from it.
+            return Dataset(uri=uri, extra=dataset_extra)
+
+        if not scheme:
+            self.log.debug(
+                "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
+            )
+            return None
+
+        dataset_factory = ProvidersManager().dataset_factories.get(scheme)
+        if not dataset_factory:
+            self.log.debug("Unsupported scheme: %s. Please provide a valid URI to create a Dataset.", scheme)
+            return None
+
+        dataset_kwargs = dataset_kwargs or {}
+        try:
+            return dataset_factory(**dataset_kwargs, extra=dataset_extra)
+        except Exception as e:
+            self.log.debug("Failed to create dataset. Skipping. Error: %s", e)
+            return None
+
+    def add_input_dataset(
+        self,
+        context: LineageContext,
+        scheme: str | None = None,
+        uri: str | None = None,
+        dataset_kwargs: dict | None = None,
+        dataset_extra: dict | None = None,
+    ):
+        """Add the input dataset and its corresponding hook execution context to the collector."""
+        dataset = self.create_dataset(
+            scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
+        )
+        if dataset:
+            self.inputs.append((dataset, context))
+
+    def add_output_dataset(
+        self,
+        context: LineageContext,
+        scheme: str | None = None,
+        uri: str | None = None,
+        dataset_kwargs: dict | None = None,
+        dataset_extra: dict | None = None,
+    ):
+        """Add the output dataset and its corresponding hook execution context to the collector."""
+        dataset = self.create_dataset(
+            scheme=scheme, uri=uri, dataset_kwargs=dataset_kwargs, dataset_extra=dataset_extra
+        )
+        if dataset:
+            self.outputs.append((dataset, context))
+
+    @property
+    def collected_datasets(self) -> HookLineage:
+        """Get the collected hook lineage information."""
+        return HookLineage(self.inputs, self.outputs)
+
+    @property
+    def has_collected(self) -> bool:
+        """Check if any datasets have been collected."""
+        return len(self.inputs) != 0 or len(self.outputs) != 0
+
+
+class NoOpCollector(HookLineageCollector):
+    """
+    NoOpCollector is a hook lineage collector that does nothing.
+
+    It is used when no HookLineageReader is registered, which effectively disables lineage collection.
+    """
+
+    def add_input_dataset(self, *_, **__):
+        pass
+
+    def add_output_dataset(self, *_, **__):
+        pass
+
+    @property
+    def collected_datasets(
+        self,
+    ) -> HookLineage:
+        self.log.warning(
+            "Data lineage tracking is disabled. Register a hook lineage reader to start tracking hook lineage."
+        )
+        return HookLineage([], [])
+
+
+class HookLineageReader(LoggingMixin):
+    """Class used to retrieve the hook lineage information collected by HookLineageCollector."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.lineage_collector = get_hook_lineage_collector()
+
+    def retrieve_hook_lineage(self) -> HookLineage:
+        """Retrieve hook lineage from HookLineageCollector."""
+        hook_lineage = self.lineage_collector.collected_datasets
+        return hook_lineage
+
+
+def get_hook_lineage_collector() -> HookLineageCollector:
+    """Get singleton lineage collector."""
+    global _hook_lineage_collector
+    if not _hook_lineage_collector:
+        from airflow import plugins_manager
+
+        plugins_manager.initialize_hook_lineage_readers_plugins()
+        if plugins_manager.hook_lineage_reader_classes:
+            _hook_lineage_collector = HookLineageCollector()
+        else:
+            _hook_lineage_collector = NoOpCollector()
+    return _hook_lineage_collector
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 63b3dbd80d47a5..76d72a45850093 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -41,6 +41,8 @@
 from airflow.utils.module_loading import import_string, qualname
 
 if TYPE_CHECKING:
+    from airflow.lineage.hook import HookLineageReader
+
     try:
         import importlib_metadata as metadata
     except ImportError:
@@ -75,6 +77,7 @@
 registered_operator_link_classes: dict[str, type] | None = None
 registered_ti_dep_classes: dict[str, type] | None = None
 timetable_classes: dict[str, type[Timetable]] | None = None
+hook_lineage_reader_classes: list[type[HookLineageReader]] | None = None
 priority_weight_strategy_classes: dict[str, type[PriorityWeightStrategy]] | None = None
 """
 Mapping of class names to class of OperatorLinks registered by plugins.
@@ -176,8 +179,12 @@
     # A list of timetable classes that can be used for DAG scheduling.
     timetables: list[type[Timetable]] = []
 
+    # A list of listeners that can be used for tracking task and DAG states.
     listeners: list[ModuleType | object] = []
 
+    # A list of hook lineage reader classes that can be used for reading lineage information from a hook.
+    hook_lineage_readers: list[type[HookLineageReader]] = []
+
     # A list of priority weight strategy classes that can be used for calculating tasks weight priority.
     priority_weight_strategies: list[type[PriorityWeightStrategy]] = []
 
@@ -483,6 +490,25 @@ def initialize_timetables_plugins():
     }
 
 
+def initialize_hook_lineage_readers_plugins():
+    """Collect hook lineage reader classes registered by plugins."""
+    global hook_lineage_reader_classes
+
+    if hook_lineage_reader_classes is not None:
+        return
+
+    ensure_plugins_loaded()
+
+    if plugins is None:
+        raise AirflowPluginException("Can't load plugins.")
+
+    log.debug("Initialize hook lineage readers plugins")
+
+    hook_lineage_reader_classes = []
+    for plugin in plugins:
+        hook_lineage_reader_classes.extend(plugin.hook_lineage_readers)
+
+
 def integrate_executor_plugins() -> None:
     """Integrate executor plugins to the context."""
     global plugins
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index 3e5e71759e2001..adbca7846d19e8 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -212,6 +212,10 @@
                     "handler": {
                         "type": ["string", "null"],
                         "description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op."
+                    },
+                    "factory": {
+                        "type": ["string", "null"],
+                        "description": "Dataset factory for specified URI schemes. Creates an AIP-60 compliant Dataset."
                     }
                 }
             }
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index ec424e14897ecc..9e9dd4d573ddd7 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -91,6 +91,7 @@ def ensure_prefix(field):
 if TYPE_CHECKING:
     from urllib.parse import SplitResult
 
+    from airflow.datasets import Dataset
     from airflow.decorators.base import TaskDecorator
     from airflow.hooks.base import BaseHook
     from airflow.typing_compat import Literal
@@ -426,6 +427,7 @@ def __init__(self):
         self._hooks_dict: dict[str, HookInfo] = {}
         self._fs_set: set[str] = set()
         self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
+        self._dataset_factories: dict[str, Callable[..., Dataset]] = {}
         self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
         # keeps mapping between connection_types and hook class, package they come from
         self._hook_provider_dict: dict[str, HookClassProvider] = {}
@@ -523,11 +525,12 @@ def initialize_providers_filesystems(self):
         self._discover_filesystems()
 
     @provider_info_cache("dataset_uris")
-    def initialize_providers_dataset_uri_handlers(self):
+    def initialize_providers_dataset_uri_handlers_and_factories(self):
         """Lazy initialization of provider dataset URI handlers."""
         self.initialize_providers_list()
-        self._discover_dataset_uri_handlers()
+        self._discover_dataset_uri_handlers_and_factories()
 
+    @provider_info_cache("hook_lineage_writers")
     @provider_info_cache("taskflow_decorators")
     def initialize_providers_taskflow_decorator(self):
         """Lazy initialization of providers hooks."""
@@ -878,7 +881,7 @@ def _discover_filesystems(self) -> None:
                 self._fs_set.add(fs_module_name)
         self._fs_set = set(sorted(self._fs_set))
 
-    def _discover_dataset_uri_handlers(self) -> None:
+    def _discover_dataset_uri_handlers_and_factories(self) -> None:
         from airflow.datasets import normalize_noop
 
         for provider_package, provider in self._provider_dict.items():
@@ -893,6 +896,13 @@ def _discover_dataset_uri_handlers(self) -> None:
                 elif not (handler := _correctness_check(provider_package, handler_path, provider)):
                     continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
+                factory_path = handler_info.get("factory")
+                if factory_path is None:
+                    continue
+                factory = _correctness_check(provider_package, factory_path, provider)
+                if not factory:
+                    continue
+                self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
@@ -1289,9 +1299,14 @@ def filesystem_module_names(self) -> list[str]:
         self.initialize_providers_filesystems()
         return sorted(self._fs_set)
 
+    @property
+    def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
+        self.initialize_providers_dataset_uri_handlers_and_factories()
+        return self._dataset_factories
+
     @property
     def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
-        self.initialize_providers_dataset_uri_handlers()
+        self.initialize_providers_dataset_uri_handlers_and_factories()
         return self._dataset_uri_handlers
 
     @property
diff --git a/docs/apache-airflow/administration-and-deployment/lineage.rst b/docs/apache-airflow/administration-and-deployment/lineage.rst
index 7b967d3f10e61a..382aaa36d598f7 100644
--- a/docs/apache-airflow/administration-and-deployment/lineage.rst
+++ b/docs/apache-airflow/administration-and-deployment/lineage.rst
@@ -89,6 +89,82 @@ has outlets defined (e.g. by using ``add_outlets(..)`` or has out of the box sup
 
 .. _precedence: https://docs.python.org/3/reference/expressions.html
 
+Hook Lineage
+------------
+
+Airflow can track data lineage not only between tasks but also from the hooks used within those tasks.
+This functionality helps you understand how data flows throughout your Airflow pipelines.
+
+A global instance of ``HookLineageCollector`` serves as the central hub for collecting lineage information.
+Hooks can send details about datasets they interact with to this collector.
+The collector then uses this data to construct AIP-60 compliant Datasets, a standard format for describing datasets.
+
+.. code-block:: python
+
+    from airflow.hooks.base import BaseHook
+    from airflow.lineage.hook import get_hook_lineage_collector
+
+
+    class CustomHook(BaseHook):
+        def run(self):
+            # run actual code
+            collector = get_hook_lineage_collector()
+            collector.add_input_dataset(self, scheme="file", dataset_kwargs={"path": "/tmp/in"})
+            collector.add_output_dataset(self, scheme="file", dataset_kwargs={"path": "/tmp/out"})
+
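+A hook that already knows the full dataset URI can pass ``uri`` directly instead of a scheme;
+``dataset_extra`` attaches additional metadata to the resulting Dataset. The snippet below is a
+minimal sketch, and the hook name, URIs, and extra values are purely illustrative.
+
+.. code-block:: python
+
+    from airflow.hooks.base import BaseHook
+    from airflow.lineage.hook import get_hook_lineage_collector
+
+
+    # An illustrative hook, not one shipped by any provider.
+    class CustomCopyHook(BaseHook):
+        def copy_file(self, source_uri: str, dest_uri: str):
+            # run the actual copy, then report both sides of the transfer
+            collector = get_hook_lineage_collector()
+            collector.add_input_dataset(self, uri=source_uri, dataset_extra={"operation": "copy"})
+            collector.add_output_dataset(self, uri=dest_uri, dataset_extra={"operation": "copy"})
+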
+Lineage data collected by the ``HookLineageCollector`` can be accessed using an instance of ``HookLineageReader``,
+which is registered in an Airflow plugin.
+
+.. code-block:: python
+
+    from airflow.lineage.hook import HookLineageReader
+    from airflow.plugins_manager import AirflowPlugin
+
+
+    class CustomHookLineageReader(HookLineageReader):
+        def get_inputs(self):
+            return self.lineage_collector.collected_datasets.inputs
+
+
+    class HookLineageCollectionPlugin(AirflowPlugin):
+        name = "HookLineageCollectionPlugin"
+        hook_lineage_readers = [CustomHookLineageReader]
+
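+Providers can register a Dataset factory for a URI scheme through the ``factory`` field in their
+``provider.yaml``. ``HookLineageCollector`` looks the factory up via ``ProvidersManager`` when a hook
+passes ``scheme`` and ``dataset_kwargs`` instead of a full URI. A rough sketch of such a factory
+(the ``fileshare`` scheme and its parameters are hypothetical, not shipped by any provider):
+
+.. code-block:: python
+
+    from airflow.datasets import Dataset
+
+
+    def create_fileshare_dataset(host: str, path: str, extra=None) -> Dataset:
+        """Build an AIP-60 compliant Dataset for the hypothetical ``fileshare`` scheme."""
+        return Dataset(uri=f"fileshare://{host}/{path.lstrip('/')}", extra=extra)
+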
+If no ``HookLineageReader`` is registered within Airflow, a default ``NoOpCollector`` is used instead.
+This collector does not create AIP-60 compliant datasets or collect lineage information.
+
 Lineage Backend
 ---------------
 
diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py
new file mode 100644
index 00000000000000..15d4d6c1e4cc1c
--- /dev/null
+++ b/tests/lineage/test_hook.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow import plugins_manager
+from airflow.datasets import Dataset
+from airflow.hooks.base import BaseHook
+from airflow.lineage import hook
+from airflow.lineage.hook import (
+    HookLineage,
+    HookLineageCollector,
+    HookLineageReader,
+    NoOpCollector,
+    get_hook_lineage_collector,
+)
+from tests.test_utils.mock_plugins import mock_plugin_manager
+
+
+class TestHookLineageCollector:
+    def test_are_datasets_collected(self):
+        lineage_collector = HookLineageCollector()
+        assert lineage_collector is not None
+        assert lineage_collector.collected_datasets == HookLineage()
+        input_hook = BaseHook()
+        output_hook = BaseHook()
+        lineage_collector.add_input_dataset(input_hook, uri="s3://in_bucket/file")
+        lineage_collector.add_output_dataset(
+            output_hook, uri="postgres://example.com:5432/database/default/table"
+        )
+        assert lineage_collector.collected_datasets == HookLineage(
+            [(Dataset("s3://in_bucket/file"), input_hook)],
+            [(Dataset("postgres://example.com:5432/database/default/table"), output_hook)],
+        )
+
+    @patch("airflow.lineage.hook.Dataset")
+    def test_add_input_dataset(self, mock_dataset):
+        collector = HookLineageCollector()
+        dataset = MagicMock(spec=Dataset)
+        mock_dataset.return_value = dataset
+
+        hook = MagicMock()
+        collector.add_input_dataset(hook, uri="test_uri")
+
+        assert collector.inputs == [(dataset, hook)]
+        mock_dataset.assert_called_once_with(uri="test_uri", extra=None)
+
+    @patch("airflow.lineage.hook.ProvidersManager")
+    def test_create_dataset(self, mock_providers_manager):
+        def create_dataset(arg1, arg2="default", extra=None):
+            return Dataset(uri=f"myscheme://{arg1}/{arg2}", extra=extra)
+
+        mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset}
+        collector = HookLineageCollector()
+        assert collector.create_dataset(
+            scheme="myscheme", uri=None, dataset_kwargs={"arg1": "value_1"}, dataset_extra=None
+        ) == Dataset("myscheme://value_1/default")
+        assert collector.create_dataset(
+            scheme="myscheme",
+            uri=None,
+            dataset_kwargs={"arg1": "value_1", "arg2": "value_2"},
+            dataset_extra={"key": "value"},
+        ) == Dataset("myscheme://value_1/value_2", extra={"key": "value"})
+
+    def test_collected_datasets(self):
+        collector = HookLineageCollector()
+        inputs = [(MagicMock(spec=Dataset), MagicMock())]
+        outputs = [(MagicMock(spec=Dataset), MagicMock())]
+        collector.inputs = inputs
+        collector.outputs = outputs
+
+        hook_lineage = collector.collected_datasets
+        assert hook_lineage.inputs == inputs
+        assert hook_lineage.outputs == outputs
+
+    def test_has_collected(self):
+        collector = HookLineageCollector()
+        assert not collector.has_collected
+
+        collector.inputs = [(MagicMock(spec=Dataset), MagicMock())]
+        assert collector.has_collected
+
+
+class FakePlugin(plugins_manager.AirflowPlugin):
+    name = "FakePluginHavingHookLineageCollector"
+    hook_lineage_readers = [HookLineageReader]
+
+
+@pytest.mark.parametrize(
+    "has_readers, expected_class",
+    [
+        (True, HookLineageCollector),
+        (False, NoOpCollector),
+    ],
+)
+def test_get_hook_lineage_collector(has_readers, expected_class):
+    # reset global variable
+    hook._hook_lineage_collector = None
+    plugins = [FakePlugin()] if has_readers else []
+    with mock_plugin_manager(plugins=plugins):
+        assert isinstance(get_hook_lineage_collector(), expected_class)
+        assert get_hook_lineage_collector() is get_hook_lineage_collector()
diff --git a/tests/test_utils/mock_plugins.py b/tests/test_utils/mock_plugins.py
index 3242f36159877e..3ab1acb730b9c0 100644
--- a/tests/test_utils/mock_plugins.py
+++ b/tests/test_utils/mock_plugins.py
@@ -33,6 +33,7 @@
     "operator_extra_links",
     "registered_operator_link_classes",
     "timetable_classes",
+    "hook_lineage_reader_classes",
 ]