From b6508e415a399d965858dd1bee182df5cf9953c3 Mon Sep 17 00:00:00 2001
From: Maciej Obuchowski
Date: Tue, 16 Jul 2024 01:27:07 +0200
Subject: [PATCH] openlineage: add support for hook lineage for S3Hook

Signed-off-by: Maciej Obuchowski
---
 airflow/lineage/hook.py                       |  4 +-
 .../providers/amazon/aws/datasets/__init__.py | 16 +++++
 airflow/providers/amazon/aws/datasets/s3.py   | 23 ++++++++
 airflow/providers/amazon/aws/hooks/s3.py      | 27 ++++++++-
 airflow/providers/amazon/provider.yaml        |  1 +
 .../providers/common/compat/lineage/hook.py   |  4 +-
 airflow/providers/common/io/datasets/file.py  |  4 +-
 airflow/providers_manager.py                  | 26 +++++-----
 dev/breeze/tests/test_selective_checks.py     |  8 +--
 generated/provider_dependencies.json          |  1 +
 tests/conftest.py                             | 10 ++++
 .../providers/amazon/aws/datasets/__init__.py | 16 +++++
 .../providers/amazon/aws/datasets/test_s3.py  | 27 +++++++++
 tests/providers/amazon/aws/hooks/test_s3.py   | 58 ++++++++++++++++++-
 14 files changed, 198 insertions(+), 27 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/datasets/__init__.py
 create mode 100644 airflow/providers/amazon/aws/datasets/s3.py
 create mode 100644 tests/providers/amazon/aws/datasets/__init__.py
 create mode 100644 tests/providers/amazon/aws/datasets/test_s3.py

diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 70893516bd9bd..ee12e1624e12d 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -139,10 +139,10 @@ class NoOpCollector(HookLineageCollector):
     It is used when you want to disable lineage collection.
     """
 
-    def add_input_dataset(self, *_):
+    def add_input_dataset(self, *_, **__):
         pass
 
-    def add_output_dataset(self, *_):
+    def add_output_dataset(self, *_, **__):
         pass
 
     @property
diff --git a/airflow/providers/amazon/aws/datasets/__init__.py b/airflow/providers/amazon/aws/datasets/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/datasets/s3.py
new file mode 100644
index 0000000000000..89889efe577b3
--- /dev/null
+++ b/airflow/providers/amazon/aws/datasets/s3.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+
+
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 8ca93766e2ed3..5f2c1366404eb 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -41,6 +41,8 @@
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@ def load_file(
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@ def _upload_file_obj(
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input dataset here: file_obj can be anything, so the caller should report the input if it can
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@ def copy_object(
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@ def download_file(
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@
             ExtraArgs=self.extra_args,
             Config=self.transfer_config,
         )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
         return file.name
 
     def generate_presigned_url(
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 478aeb230c7bb..a2ead1ea5a44b 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -549,6 +549,7 @@ sensors:
 dataset-uris:
   - schemes: [s3]
     handler: null
+    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
 
 filesystems:
   - airflow.providers.amazon.aws.fs.s3
diff --git a/airflow/providers/common/compat/lineage/hook.py b/airflow/providers/common/compat/lineage/hook.py
index 2115c992e7a41..dbdbc5bf86f4d 100644
--- a/airflow/providers/common/compat/lineage/hook.py
+++ b/airflow/providers/common/compat/lineage/hook.py
@@ -32,10 +32,10 @@ class NoOpCollector:
             It is used when you want to disable lineage collection.
             """
 
-            def add_input_dataset(self, *_):
+            def add_input_dataset(self, *_, **__):
                 pass
 
-            def add_output_dataset(self, *_):
+            def add_output_dataset(self, *_, **__):
                 pass
 
         return NoOpCollector()
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/datasets/file.py
index 46c7499037e06..1bc4969762b85 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/common/io/datasets/file.py
@@ -19,6 +19,6 @@
 from airflow.datasets import Dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
+def create_dataset(*, path: str, extra=None) -> Dataset:
     # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+    return Dataset(uri=f"file://{path}", extra=extra)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 9e9dd4d573ddd..56463d511b83a 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -886,23 +886,23 @@ def _discover_dataset_uri_handlers_and_factories(self) -> None:
 
         for provider_package, provider in self._provider_dict.items():
             for handler_info in provider.data.get("dataset-uris", []):
-                try:
-                    schemes = handler_info["schemes"]
-                    handler_path = handler_info["handler"]
-                except KeyError:
+                schemes = handler_info.get("schemes")
+                handler_path = handler_info.get("handler")
+                factory_path = handler_info.get("factory")
+                if schemes is None:
                     continue
-                if handler_path is None:
+
+                if handler_path is not None and (
+                    handler := _correctness_check(provider_package, handler_path, provider)
+                ):
+                    pass
+                else:
                     handler = normalize_noop
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
-                    continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                factory_path = handler_info.get("factory")
-                if not (
-                    factory_path is not None
-                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                if factory_path is not None and (
+                    factory := _correctness_check(provider_package, factory_path, provider)
                 ):
-                    continue
-                self._dataset_factories.update((scheme, factory) for scheme in schemes)
+                    self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 513edeff5f59a..cff38c9196d61 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -519,7 +519,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/__init__.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -535,7 +535,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
@@ -569,7 +569,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/file.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -585,7 +585,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-kubernetes-tests": "false",
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index bc7e5eee3a02b..585ff9b150913 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -57,6 +57,7 @@
     "cross-providers-deps": [
       "apache.hive",
       "cncf.kubernetes",
+      "common.compat",
      "common.sql",
       "exasol",
       "ftp",
diff --git a/tests/conftest.py b/tests/conftest.py
index 9027391575e4c..6cb74446dce8b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1326,6 +1326,16 @@ def airflow_root_path() -> Path:
     return Path(airflow.__path__[0]).parent
 
 
+@pytest.fixture
+def hook_lineage_collector():
+    from airflow.lineage import hook
+
+    hook._hook_lineage_collector = None
+    hook._hook_lineage_collector = hook.HookLineageCollector()
+    yield hook.get_hook_lineage_collector()
+    hook._hook_lineage_collector = None
+
+
 # This constant is set to True if tests are run with Airflow installed from Packages rather than running
 # the tests within Airflow sources. While most tests in CI are run using Airflow sources, there are
 # also compatibility tests that only use `tests` package and run against installed packages of Airflow in
diff --git a/tests/providers/amazon/aws/datasets/__init__.py b/tests/providers/amazon/aws/datasets/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/amazon/aws/datasets/test_s3.py b/tests/providers/amazon/aws/datasets/test_s3.py
new file mode 100644
index 0000000000000..c7ffe252401e7
--- /dev/null
+++ b/tests/providers/amazon/aws/datasets/test_s3.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.datasets.s3 import create_dataset
+
+
+def test_create_dataset():
+    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path")
+    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
+        uri="s3://test-bucket/test-dir/test-path"
+    )
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 6b10173d3c6ed..51f2ec2c0add2 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -31,6 +31,7 @@
 from botocore.exceptions import ClientError
 from moto import mock_aws
 
+from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
@@ -388,6 +389,14 @@ def test_load_string(self, s3_bucket):
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert resource.get()["Body"].read() == b"Cont\xc3\xa9nt"
 
+    def test_load_string_exposes_lineage(self, s3_bucket, hook_lineage_collector):
+        hook = S3Hook()
+        hook.load_string("Contént", "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_string_compress(self, s3_bucket):
         hook = S3Hook()
         hook.load_string("Contént", "my_key", s3_bucket, compression="gzip")
@@ -970,6 +979,16 @@ def test_load_file_gzip(self, s3_bucket, tmp_path):
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert gz.decompress(resource.get()["Body"].read()) == b"Content"
 
+    def test_load_file_exposes_lineage(self, s3_bucket, tmp_path, hook_lineage_collector):
+        hook = S3Hook()
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_file_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
         path = tmp_path / "testfile"
@@ -1027,6 +1046,25 @@ def test_copy_object_no_acl(
             ACL="private",
         )
 
+    @mock_aws
+    def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector):
+        mock_hook = S3Hook()
+
+        with mock.patch.object(
+            S3Hook,
+            "get_conn",
+        ):
+            mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key3"
+        )
+
     @mock_aws
     def test_delete_bucket_if_bucket_exist(self, s3_bucket):
         # assert if the bucket is created
@@ -1118,7 +1156,7 @@ def test_function_with_test_key(self, test_key, bucket_name=None):
         assert isinstance(ctx.value, ValueError)
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
-    def test_download_file(self, mock_temp_file, tmp_path):
+    def test_download_file(self, mock_temp_file, tmp_path, hook_lineage_collector):
         path = tmp_path / "airflow_tmp_test_s3_hook"
         mock_temp_file.return_value = path
         s3_hook = S3Hook(aws_conn_id="s3_test")
@@ -1139,9 +1177,13 @@ def test_download_file(self, mock_temp_file, tmp_path):
         )
 
         assert path.name == output_file
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key"
+        )
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
-    def test_download_file_with_preserve_name(self, mock_open, tmp_path):
+    def test_download_file_with_preserve_name(self, mock_open, tmp_path, hook_lineage_collector):
         path = tmp_path / "test.log"
         bucket = "test_bucket"
         key = f"test_key/{path.name}"
@@ -1152,15 +1194,25 @@ def test_download_file_with_preserve_name(self, mock_open, tmp_path):
         s3_obj.key = f"s3://{bucket}/{key}"
         s3_obj.download_fileobj = Mock(return_value=None)
         s3_hook.get_key = Mock(return_value=s3_obj)
+        local_path = os.fspath(path.parent)
         s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=os.fspath(path.parent),
+            local_path=local_path,
             preserve_file_name=True,
             use_autogenerated_subdir=False,
         )
 
         mock_open.assert_called_once_with(path, "wb")
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key/test.log"
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"file://{local_path}/test.log",
+        )
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
     def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open, tmp_path):
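
Note for reviewers: the sketch below (not part of the patch) illustrates the collector API that the S3Hook changes above rely on. It assumes an environment with this patch plus the amazon and common.compat providers installed; the MyS3ArchiveHook class and its archive() method are hypothetical and only mirror the pattern used in S3Hook.load_file, copy_object, and download_file.

    # Illustrative sketch only; MyS3ArchiveHook is a hypothetical example hook.
    from __future__ import annotations

    from airflow.hooks.base import BaseHook
    from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector


    class MyS3ArchiveHook(BaseHook):
        """Hypothetical hook that copies an S3 object into an archive/ prefix."""

        def archive(self, bucket: str, key: str) -> None:
            # ... perform the actual S3 copy with boto3 here ...
            collector = get_hook_lineage_collector()
            # scheme="s3" is resolved through the dataset factory registered in
            # provider.yaml (airflow.providers.amazon.aws.datasets.s3.create_dataset),
            # which builds Dataset(uri=f"s3://{bucket}/{key}").
            collector.add_input_dataset(
                context=self, scheme="s3", dataset_kwargs={"bucket": bucket, "key": key}
            )
            collector.add_output_dataset(
                context=self, scheme="s3", dataset_kwargs={"bucket": bucket, "key": f"archive/{key}"}
            )

With lineage collection enabled, the reported datasets are exposed via the collector's collected_datasets.inputs and collected_datasets.outputs, which is exactly what the new hook_lineage_collector pytest fixture asserts against in tests/providers/amazon/aws/hooks/test_s3.py.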