diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py index 4de148de84fabd..9136c29b2c4ffb 100644 --- a/airflow/datasets/__init__.py +++ b/airflow/datasets/__init__.py @@ -41,6 +41,11 @@ def normalize_noop(parts: SplitResult) -> SplitResult: return parts +def create_dataset(uri: str) -> Dataset: + """Create a dataset object from a dataset URI.""" + return Dataset(uri=uri) + + def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None: if scheme == "file": return normalize_noop diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py new file mode 100644 index 00000000000000..9510f1b2d4a111 --- /dev/null +++ b/airflow/lineage/hook.py @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Union + +import attr + +from airflow.datasets import Dataset, create_dataset +from airflow.hooks.base import BaseHook +from airflow.io.store import ObjectStore +from airflow.providers_manager import ProvidersManager +from airflow.utils.log.logging_mixin import LoggingMixin + +# Stores the context (hook or object store) that sent the lineage.
+LineageContext = Union[BaseHook, ObjectStore] + +_hook_lineage_collector: HookLineageCollector | None = None + + +@attr.define +class HookLineage: + """Holds lineage collected by HookLineageCollector.""" + + inputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list) + outputs: list[tuple[Dataset, LineageContext]] = attr.ib(factory=list) + + +class HookLineageCollector(LoggingMixin): + """ + HookLineageCollector is a base class for collecting hook lineage information. + + It is used to collect the input and output datasets of a hook execution. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.inputs: list[tuple[Dataset, LineageContext]] = [] + self.outputs: list[tuple[Dataset, LineageContext]] = [] + + @staticmethod + def create_dataset(dataset_kwargs: dict) -> Dataset: + """Create a Dataset instance from the given dataset kwargs.""" + if "uri" in dataset_kwargs: + # Fallback to default factory using the provided URI + return create_dataset(dataset_kwargs["uri"]) + + scheme: str = dataset_kwargs.pop("scheme", None) + if not scheme: + raise ValueError( + "Missing required parameter: either 'uri' or 'scheme' must be provided to create a Dataset." + ) + + dataset_factory = ProvidersManager().dataset_factories.get(scheme) + if not dataset_factory: + raise ValueError( + f"Unsupported scheme: '{scheme}'. Please provide a valid URI to create a Dataset." 
) + + return dataset_factory(**dataset_kwargs) + + def add_input_dataset(self, dataset_kwargs: dict, hook: LineageContext): + """Add the input dataset and its corresponding hook execution context to the collector.""" + dataset = self.create_dataset(dataset_kwargs) + self.inputs.append((dataset, hook)) + + def add_output_dataset(self, dataset_kwargs: dict, hook: LineageContext): + """Add the output dataset and its corresponding hook execution context to the collector.""" + dataset = self.create_dataset(dataset_kwargs) + self.outputs.append((dataset, hook)) + + @property + def collected_datasets(self) -> HookLineage: + """Get the collected hook lineage information.""" + return HookLineage(self.inputs, self.outputs) + + @property + def has_collected(self) -> bool: + """Check if any datasets have been collected.""" + return len(self.inputs) != 0 or len(self.outputs) != 0 + + +class NoOpCollector(HookLineageCollector): + """ + NoOpCollector is a hook lineage collector that does nothing. + + It is used when you want to disable lineage collection. + """ + + def add_input_dataset(self, *_): + pass + + def add_output_dataset(self, *_): + pass + + @property + def collected_datasets( + self, + ) -> HookLineage: + self.log.warning("You should not call this as there's no reader.") + return HookLineage([], []) + + +class HookLineageReader(LoggingMixin): + """Class used to retrieve the hook lineage information collected by HookLineageCollector.""" + + def __init__(self, **kwargs): + self.lineage_collector = get_hook_lineage_collector() + + def retrieve_hook_lineage(self) -> HookLineage: + """Retrieve hook lineage from HookLineageCollector.""" + hook_lineage = self.lineage_collector.collected_datasets + return hook_lineage + + +def get_hook_lineage_collector() -> HookLineageCollector: + """Get singleton lineage collector.""" + global _hook_lineage_collector + if not _hook_lineage_collector: + # TODO: is there a better way to plug in the no-op collector?
+ if ProvidersManager().hook_lineage_readers: + _hook_lineage_collector = HookLineageCollector() + else: + _hook_lineage_collector = NoOpCollector() + return _hook_lineage_collector diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json index 3e5e71759e2001..0498ad916ab837 100644 --- a/airflow/provider.yaml.schema.json +++ b/airflow/provider.yaml.schema.json @@ -212,10 +212,21 @@ "handler": { "type": ["string", "null"], "description": "Normalization function for specified URI schemes. Import path to a callable taking and returning a SplitResult. 'null' specifies a no-op." + }, + "factory": { + "type": ["string", "null"], + "description": "Dataset factory for specified URI. Creates AIP-60 compliant Dataset." } } } }, + "hook-lineage-readers": { + "type": "array", + "description": "Hook lineage readers", + "items": { + "type": "string" + } + }, "transfers": { "type": "array", "items": { diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index ec424e14897ecc..eeaaa3e083c7c4 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -91,6 +91,7 @@ def ensure_prefix(field): if TYPE_CHECKING: from urllib.parse import SplitResult + from airflow.datasets import Dataset from airflow.decorators.base import TaskDecorator from airflow.hooks.base import BaseHook from airflow.typing_compat import Literal @@ -426,6 +427,8 @@ def __init__(self): self._hooks_dict: dict[str, HookInfo] = {} self._fs_set: set[str] = set() self._dataset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {} + self._dataset_factories: dict[str, Callable[..., Dataset]] = {} + self._hook_lineage_readers: set[str] = set() self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache() # type: ignore[assignment] # keeps mapping between connection_types and hook class, package they come from self._hook_provider_dict: dict[str, HookClassProvider] = {} @@ -526,8 +529,15 @@ def initialize_providers_filesystems(self): def 
initialize_providers_dataset_uri_handlers(self): """Lazy initialization of provider dataset URI handlers.""" self.initialize_providers_list() - self._discover_dataset_uri_handlers() + self._discover_dataset_uri_handlers_and_factories() + @provider_info_cache("hook_lineage_readers") + def initialize_providers_hook_lineage_readers(self): + """Lazy initialization of providers hook lineage readers.""" + self.initialize_providers_list() + self._discover_hook_lineage_readers() + + @provider_info_cache("hook_lineage_writers") @provider_info_cache("taskflow_decorators") def initialize_providers_taskflow_decorator(self): """Lazy initialization of providers hooks.""" @@ -564,7 +574,7 @@ def initialize_providers_notifications(self): self.initialize_providers_list() self._discover_notifications() - @provider_info_cache("auth_managers") + @provider_info_cache(cache_name="auth_managers") def initialize_providers_auth_managers(self): """Lazy initialization of providers notifications information.""" self.initialize_providers_list() @@ -878,21 +888,34 @@ def _discover_filesystems(self) -> None: self._fs_set.add(fs_module_name) self._fs_set = set(sorted(self._fs_set)) - def _discover_dataset_uri_handlers(self) -> None: - from airflow.datasets import normalize_noop + def _discover_dataset_uri_handlers_and_factories(self) -> None: + from airflow.datasets import create_dataset, normalize_noop for provider_package, provider in self._provider_dict.items(): for handler_info in provider.data.get("dataset-uris", []): try: schemes = handler_info["schemes"] handler_path = handler_info["handler"] + factory_path = handler_info["factory"] except KeyError: continue if handler_path is None: handler = normalize_noop - elif not (handler := _correctness_check(provider_package, handler_path, provider)): + if factory_path is None: + factory = create_dataset + elif not (handler := _correctness_check(provider_package, handler_path, provider)) or not ( + factory := _correctness_check(provider_package, 
factory_path, provider) + ): continue self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes) + self._dataset_factories.update((scheme, factory) for scheme in schemes) + + def _discover_hook_lineage_readers(self) -> None: + for provider_package, provider in self._provider_dict.items(): + for hook_lineage_reader in provider.data.get("hook-lineage-readers", []): + if _correctness_check(provider_package, hook_lineage_reader, provider): + self._hook_lineage_readers.add(hook_lineage_reader) + self._fs_set = set(sorted(self._fs_set)) def _discover_taskflow_decorators(self) -> None: for name, info in self._provider_dict.items(): @@ -1289,11 +1312,21 @@ def filesystem_module_names(self) -> list[str]: self.initialize_providers_filesystems() return sorted(self._fs_set) + @property + def dataset_factories(self) -> dict[str, Callable[..., Dataset]]: + self.initialize_providers_dataset_uri_handlers() + return self._dataset_factories + @property def dataset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]: self.initialize_providers_dataset_uri_handlers() return self._dataset_uri_handlers + @property + def hook_lineage_readers(self) -> list[str]: + self.initialize_providers_hook_lineage_readers() + return sorted(self._hook_lineage_readers) + @property def provider_configs(self) -> list[tuple[str, dict[str, Any]]]: self.initialize_providers_configuration() diff --git a/tests/lineage/test_hook.py b/tests/lineage/test_hook.py new file mode 100644 index 00000000000000..876b6e04109cda --- /dev/null +++ b/tests/lineage/test_hook.py @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.datasets import Dataset +from airflow.hooks.base import BaseHook +from airflow.lineage import hook +from airflow.lineage.hook import HookLineage, HookLineageCollector, NoOpCollector, get_hook_lineage_collector + + +class TestHookLineageCollector: + def test_are_datasets_collected(self): + lineage_collector = HookLineageCollector() + assert lineage_collector is not None + assert lineage_collector.collected_datasets == HookLineage() + input_hook = BaseHook() + output_hook = BaseHook() + lineage_collector.add_input_dataset({"uri": "s3://in_bucket/file"}, input_hook) + lineage_collector.add_output_dataset( + {"uri": "postgres://example.com:5432/database/default/table"}, output_hook + ) + assert lineage_collector.collected_datasets == HookLineage( + [(Dataset("s3://in_bucket/file"), input_hook)], + [(Dataset("postgres://example.com:5432/database/default/table"), output_hook)], + ) + + @patch("airflow.lineage.hook.create_dataset") + def test_add_input_dataset(self, mock_create_dataset): + collector = HookLineageCollector() + mock_dataset = MagicMock(spec=Dataset) + mock_create_dataset.return_value = mock_dataset + + dataset_kwargs = {"uri": "test_uri"} + hook = MagicMock() + collector.add_input_dataset(dataset_kwargs, hook) + + assert collector.inputs == [(mock_dataset, hook)] + mock_create_dataset.assert_called_once_with("test_uri") + + @patch("airflow.lineage.hook.ProvidersManager") + def test_create_dataset(self, mock_providers_manager): + def 
create_dataset(arg1, arg2="default"): + return Dataset(uri=f"myscheme://{arg1}/{arg2}") + + mock_providers_manager.return_value.dataset_factories = {"myscheme": create_dataset} + collector = HookLineageCollector() + assert collector.create_dataset({"scheme": "myscheme", "arg1": "value_1"}) == Dataset( + "myscheme://value_1/default" + ) + assert collector.create_dataset( + {"scheme": "myscheme", "arg1": "value_1", "arg2": "value_2"} + ) == Dataset("myscheme://value_1/value_2") + + def test_collected_datasets(self): + collector = HookLineageCollector() + inputs = [(MagicMock(spec=Dataset), MagicMock())] + outputs = [(MagicMock(spec=Dataset), MagicMock())] + collector.inputs = inputs + collector.outputs = outputs + + hook_lineage = collector.collected_datasets + assert hook_lineage.inputs == inputs + assert hook_lineage.outputs == outputs + + def test_has_collected(self): + collector = HookLineageCollector() + assert not collector.has_collected + + collector.inputs = [MagicMock(spec=Dataset), MagicMock()] + assert collector.has_collected + + +@pytest.mark.parametrize( + "has_readers, expected_class", + [ + (True, HookLineageCollector), + (False, NoOpCollector), + ], +) +@patch("airflow.lineage.hook.ProvidersManager") +def test_get_hook_lineage_collector(mock_providers_manager, has_readers, expected_class): + # reset global variable + hook._hook_lineage_collector = None + if has_readers: + mock_providers_manager.return_value.hook_lineage_readers = [MagicMock()] + else: + mock_providers_manager.return_value.hook_lineage_readers = [] + assert isinstance(get_hook_lineage_collector(), expected_class) + assert get_hook_lineage_collector() is get_hook_lineage_collector()