def create_dataset(uri: str) -> Dataset:
    """Default dataset factory: wrap *uri* verbatim in a :class:`Dataset`.

    Used as the fallback when no scheme-specific provider factory is
    registered for the URI.
    """
    dataset = Dataset(uri=uri)
    return dataset
from __future__ import annotations

from typing import Union

import attr

from airflow.datasets import Dataset, create_dataset
from airflow.hooks.base import BaseHook
from airflow.io.store import ObjectStore
from airflow.providers_manager import ProvidersManager
from airflow.utils.log.logging_mixin import LoggingMixin

# The execution context that emitted the lineage: a hook or an object store.
LineageContext = Union[BaseHook, ObjectStore]

# Process-wide singleton, created lazily by get_hook_lineage_collector().
_hook_lineage_collector: HookLineageCollector | None = None


@attr.define
class HookLineage:
    """Holds lineage collected by HookLineageCollector."""

    # Each entry pairs a dataset with the context (hook) that produced it.
    inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)
    outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list)


class HookLineageCollector(LoggingMixin):
    """
    HookLineageCollector is a base class for collecting hook lineage information.

    It is used to collect the input and output datasets of a hook execution.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.inputs: list[tuple[Dataset, LineageContext]] = []
        self.outputs: list[tuple[Dataset, LineageContext]] = []

    @staticmethod
    def create_dataset(dataset_kwargs: dict) -> Dataset:
        """Create a Dataset instance from the given dataset kwargs.

        Either ``uri`` (used verbatim with the default factory) or ``scheme``
        (dispatched to a provider-registered factory, with the remaining
        kwargs forwarded to it) must be present.

        :raises ValueError: if neither ``uri`` nor ``scheme`` is provided, or
            no factory is registered for the given scheme.
        """
        if "uri" in dataset_kwargs:
            # Fallback to default factory using the provided URI
            return create_dataset(dataset_kwargs["uri"])

        # Work on a copy so the caller's dict is not mutated by the pop below.
        dataset_kwargs = dict(dataset_kwargs)
        scheme: str | None = dataset_kwargs.pop("scheme", None)
        if not scheme:
            raise ValueError(
                "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset."
            )

        dataset_factory = ProvidersManager().dataset_factories.get(scheme)
        if not dataset_factory:
            raise ValueError(
                f"Unsupported scheme: '{scheme}'. Please provide a valid URI to create a Dataset."
            )

        return dataset_factory(**dataset_kwargs)

    def add_input_dataset(self, dataset_kwargs: dict, hook: LineageContext):
        """Add the input dataset and its corresponding hook execution context to the collector."""
        dataset = self.create_dataset(dataset_kwargs)
        self.inputs.append((dataset, hook))

    def add_output_dataset(self, dataset_kwargs: dict, hook: LineageContext):
        """Add the output dataset and its corresponding hook execution context to the collector."""
        dataset = self.create_dataset(dataset_kwargs)
        self.outputs.append((dataset, hook))

    @property
    def collected_datasets(self) -> HookLineage:
        """Get the collected hook lineage information."""
        return HookLineage(self.inputs, self.outputs)

    @property
    def has_collected(self) -> bool:
        """Check if any datasets have been collected."""
        return bool(self.inputs) or bool(self.outputs)


class NoOpCollector(HookLineageCollector):
    """
    NoOpCollector is a hook lineage collector that does nothing.

    It is used when you want to disable lineage collection.
    """

    # Accept keyword arguments too, so call sites using the base-class
    # signature (dataset_kwargs=..., hook=...) do not break.
    def add_input_dataset(self, *_, **__):
        pass

    def add_output_dataset(self, *_, **__):
        pass

    @property
    def collected_datasets(self) -> HookLineage:
        self.log.warning("You should not call this as there's no reader.")
        return HookLineage([], [])


class HookLineageReader(LoggingMixin):
    """Class used to retrieve the hook lineage information collected by HookLineageCollector."""

    def __init__(self, **kwargs):
        # Was missing: run the LoggingMixin initialization chain.
        super().__init__(**kwargs)
        self.lineage_collector = get_hook_lineage_collector()

    def retrieve_hook_lineage(self) -> HookLineage:
        """Retrieve hook lineage from HookLineageCollector."""
        return self.lineage_collector.collected_datasets


def get_hook_lineage_collector() -> HookLineageCollector:
    """Get the singleton lineage collector, creating it on first use.

    A real collector is only created when at least one provider registers a
    hook lineage reader; otherwise a no-op collector is installed.
    """
    global _hook_lineage_collector
    if not _hook_lineage_collector:
        if ProvidersManager().hook_lineage_readers:
            _hook_lineage_collector = HookLineageCollector()
        else:
            _hook_lineage_collector = NoOpCollector()
    return _hook_lineage_collector
} } } }, + "hook-lineage-readers": { + "type": "array", + "description": "Hook lineage readers", + "items": { + "type": "string" + } + }, "transfers": { "type": "array", "items": { diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index ec424e14897ec..eeaaa3e083c7c 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -91,6 +91,7 @@ def ensure_prefix(field): if TYPE_CHECKING: from urllib.parse import SplitResult + from airflow.datasets import Dataset from airflow.decorators.base import TaskDecorator from airflow.hooks.base import BaseHook from airflow.typing_compat import Literal @@ -426,6 +427,8 @@ def __init__(self): self._hooks_dict: dict[str, HookInfo] = {} self._fs_set: set[str] = set() self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {} + self._dataset_factories: dict[str, Callable[..., Dataset]] = {} + self._hook_lineage_readers: set[str] = set() self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # type: ignore[assignment] # keeps mapping between connection_types and hook class, package they come from self._hook_provider_dict: dict[str, HookClassProvider] = {} @@ -526,8 +529,15 @@ def initialize_providers_filesystems(self): def initialize_providers_dataset_uri_handlers(self): """Lazy initialization of provider dataset URI handlers.""" self.initialize_providers_list() - self._discover_dataset_uri_handlers() + self._discover_dataset_uri_handlers_and_factories() + @provider_info_cache("hook_lineage_readers") + def initialize_providers_hook_lineage_readers(self): + """Lazy initialization of providers hook lineage readers.""" + self.initialize_providers_list() + self._discover_hook_lineage_readers() + + @provider_info_cache("hook_lineage_writers") @provider_info_cache("taskflow_decorators") def initialize_providers_taskflow_decorator(self): """Lazy initialization of providers hooks.""" @@ -564,7 +574,7 @@ def initialize_providers_notifications(self): 
self.initialize_providers_list() self._discover_notifications() - @provider_info_cache("auth_managers") + @provider_info_cache(cache_name="auth_managers") def initialize_providers_auth_managers(self): """Lazy initialization of providers notifications information.""" self.initialize_providers_list() @@ -878,21 +888,34 @@ def _discover_filesystems(self) -> None: self._fs_set.add(fs_module_name) self._fs_set = set(sorted(self._fs_set)) - def _discover_dataset_uri_handlers(self) -> None: - from airflow.datasets import normalize_noop + def _discover_dataset_uri_handlers_and_factories(self) -> None: + from airflow.datasets import create_dataset, normalize_noop for provider_package, provider in self._provider_dict.items(): for handler_info in provider.data.get("dataset-uris", []): try: schemes = handler_info["schemes"] handler_path = handler_info["handler"] + factory_path = handler_info["factory"] except KeyError: continue if handler_path is None: handler = normalize_noop - elif not (handler := _correctness_check(provider_package, handler_path, provider)): + if factory_path is None: + factory = create_dataset + elif not (handler := _correctness_check(provider_package, handler_path, provider)) or not ( + factory := _correctness_check(provider_package, factory_path, provider) + ): continue self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes) + self._dataset_factories.update((scheme, factory) for scheme in schemes) + + def _discover_hook_lineage_readers(self) -> None: + for provider_package, provider in self._provider_dict.items(): + for hook_lineage_reader in provider.data.get("hook-lineage-readers", []): + if _correctness_check(provider_package, hook_lineage_reader, provider): + self._hook_lineage_readers.add(hook_lineage_reader) + self._fs_set = set(sorted(self._fs_set)) def _discover_taskflow_decorators(self) -> None: for name, info in self._provider_dict.items(): @@ -1289,11 +1312,21 @@ def filesystem_module_names(self) -> list[str]: 
self.initialize_providers_filesystems() return sorted(self._fs_set) + @property + def dataset_factories(self) -> dict[str, Callable[..., Dataset]]: + self.initialize_providers_dataset_uri_handlers() + return self._dataset_factories + @property def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]: self.initialize_providers_dataset_uri_handlers() return self._dataset_uri_handlers + @property + def hook_lineage_readers(self) -> list[str]: + self.initialize_providers_hook_lineage_readers() + return sorted(self._hook_lineage_readers) + @property def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: self.initialize_providers_configuration() diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py new file mode 100644 index 0000000000000..876b6e04109cd --- /dev/null +++ b/tests/lineage/test_hook.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.datasets import Dataset +from airflow.hooks.base import BaseHook +from airflow.lineage import hook +from airflow.lineage.hook import HookLineage, HookLineageCollector, NoOpCollector, get_hook_lineage_collector + + +class TestHookLineageCollector: + def test_are_datasets_collected(self): + lineage_collector = HookLineageCollector() + assert lineage_collector is not None + assert lineage_collector.collected_datasets == HookLineage() + input_hook = BaseHook() + output_hook = BaseHook() + lineage_collector.add_input_dataset({"uri": "s3://in_bucket/file"}, input_hook) + lineage_collector.add_output_dataset( + {"uri": "postgres://example.com:5432/database/default/table"}, output_hook + ) + assert lineage_collector.collected_datasets == HookLineage( + [(Dataset("s3://in_bucket/file"), input_hook)], + [(Dataset("postgres://example.com:5432/database/default/table"), output_hook)], + ) + + @patch("airflow.lineage.hook.create_dataset") + def test_add_input_dataset(self, mock_create_dataset): + collector = HookLineageCollector() + mock_dataset = MagicMock(spec=Dataset) + mock_create_dataset.return_value = mock_dataset + + dataset_kwargs = {"uri": "test_uri"} + hook = MagicMock() + collector.add_input_dataset(dataset_kwargs, hook) + + assert collector.inputs == [(mock_dataset, hook)] + mock_create_dataset.assert_called_once_with("test_uri") + + @patch("airflow.lineage.hook.ProvidersManager") + def test_create_dataset(self, mock_providers_manager): + def create_dataset(arg1, arg2="default"): + return Dataset(uri=f"myscheme://{arg1}/{arg2}") + + mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset} + collector = HookLineageCollector() + assert collector.create_dataset({"scheme": "myscheme", "arg1": "value_1"}) == Dataset( + "myscheme://value_1/default" + ) + assert collector.create_dataset( + {"scheme": "myscheme", "arg1": "value_1", 
"arg2": "value_2"} + ) == Dataset("myscheme://value_1/value_2") + + def test_collected_datasets(self): + collector = HookLineageCollector() + inputs = [(MagicMock(spec=Dataset), MagicMock())] + outputs = [(MagicMock(spec=Dataset), MagicMock())] + collector.inputs = inputs + collector.outputs = outputs + + hook_lineage = collector.collected_datasets + assert hook_lineage.inputs == inputs + assert hook_lineage.outputs == outputs + + def test_has_collected(self): + collector = HookLineageCollector() + assert not collector.has_collected + + collector.inputs = [MagicMock(spec=Dataset), MagicMock()] + assert collector.has_collected + + +@pytest.mark.parametrize( + "has_readers, expected_class", + [ + (True, HookLineageCollector), + (False, NoOpCollector), + ], +) +@patch("airflow.lineage.hook.ProvidersManager") +def test_get_hook_lineage_collector(mock_providers_manager, has_readers, expected_class): + # reset global variable + hook._hook_lineage_collector = None + if has_readers: + mock_providers_manager.return_value.hook_lineage_readers = [MagicMock()] + else: + mock_providers_manager.return_value.hook_lineage_readers = [] + assert isinstance(get_hook_lineage_collector(), expected_class) + assert get_hook_lineage_collector() is get_hook_lineage_collector()