openlineage: add support for hook lineage for S3Hook
Signed-off-by: Maciej Obuchowski <[email protected]>
mobuchowski committed Jul 16, 2024
1 parent 6366204 commit b6508e4
Showing 14 changed files with 200 additions and 28 deletions.
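In short: S3Hook methods now report what they read and wrote to the hook-level lineage collector, and the amazon provider registers a dataset factory for the s3 scheme so those reports become Dataset objects. A minimal sketch of the registration pattern the commit applies follows; the hook class and method below are hypothetical and only the collector calls mirror what the diff adds.

from airflow.hooks.base import BaseHook
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector


class MyTransferHook(BaseHook):  # hypothetical hook, for illustration only
    def copy(self, src_bucket: str, src_key: str, dst_bucket: str, dst_key: str) -> None:
        # ... perform the actual transfer here ...
        collector = get_hook_lineage_collector()
        # Report what was read and what was written; building the Dataset object is
        # delegated to the per-scheme factory declared in provider.yaml.
        collector.add_input_dataset(
            context=self, scheme="s3", dataset_kwargs={"bucket": src_bucket, "key": src_key}
        )
        collector.add_output_dataset(
            context=self, scheme="s3", dataset_kwargs={"bucket": dst_bucket, "key": dst_key}
        )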
4 changes: 2 additions & 2 deletions airflow/lineage/hook.py
@@ -139,10 +139,10 @@ class NoOpCollector(HookLineageCollector):
    It is used when you want to disable lineage collection.
    """

-    def add_input_dataset(self, *_):
+    def add_input_dataset(self, *_, **__):
        pass

-    def add_output_dataset(self, *_):
+    def add_output_dataset(self, *_, **__):
        pass

    @property
16 changes: 16 additions & 0 deletions airflow/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
23 changes: 23 additions & 0 deletions airflow/providers/amazon/aws/datasets/s3.py
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.datasets import Dataset


def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
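A quick usage sketch of this factory, mirroring the unit test added at the end of this commit:

from airflow.providers.amazon.aws.datasets.s3 import create_dataset

dataset = create_dataset(bucket="test-bucket", key="test-dir/test-path")
assert dataset.uri == "s3://test-bucket/test-dir/test-path"  # Dataset carries the composed s3:// URI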
27 changes: 26 additions & 1 deletion airflow/providers/amazon/aws/hooks/s3.py
@@ -41,6 +41,8 @@
from urllib.parse import urlsplit
from uuid import uuid4

+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
if TYPE_CHECKING:
    from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

@@ -1111,6 +1113,12 @@ def load_file(

        client = self.get_conn()
        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )

    @unify_bucket_name_and_key
    @provide_bucket_name
@@ -1251,6 +1259,10 @@ def _upload_file_obj(
            ExtraArgs=extra_args,
            Config=self.transfer_config,
        )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )

    def copy_object(
        self,
@@ -1306,6 +1318,12 @@ def copy_object(
        response = self.get_conn().copy_object(
            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
        )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
        return response

    @provide_bucket_name
@@ -1425,6 +1443,11 @@ def download_file(

            file_path.parent.mkdir(exist_ok=True, parents=True)

+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
            file = open(file_path, "wb")
        else:
            file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@
                ExtraArgs=self.extra_args,
                Config=self.transfer_config,
            )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
        return file.name

    def generate_presigned_url(
1 change: 1 addition & 0 deletions airflow/providers/amazon/provider.yaml
@@ -549,6 +549,7 @@ sensors:
dataset-uris:
  - schemes: [s3]
    handler: null
+    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset

filesystems:
  - airflow.providers.amazon.aws.fs.s3
4 changes: 2 additions & 2 deletions airflow/providers/common/compat/lineage/hook.py
@@ -32,10 +32,10 @@ class NoOpCollector:
            It is used when you want to disable lineage collection.
            """

-            def add_input_dataset(self, *_):
+            def add_input_dataset(self, *_, **__):
                pass

-            def add_output_dataset(self, *_):
+            def add_output_dataset(self, *_, **__):
                pass

        return NoOpCollector()
4 changes: 2 additions & 2 deletions airflow/providers/common/io/datasets/file.py
@@ -19,6 +19,6 @@
from airflow.datasets import Dataset


-def create_dataset(*, path: str) -> Dataset:
+def create_dataset(*, path: str, extra=None) -> Dataset:
    # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+    return Dataset(uri=f"file://{path}", extra=extra)
29 changes: 15 additions & 14 deletions airflow/providers_manager.py
@@ -524,7 +524,7 @@ def initialize_providers_filesystems(self):
        self.initialize_providers_list()
        self._discover_filesystems()

-    @provider_info_cache("dataset_uris")
+    # @provider_info_cache("dataset_uris")
    def initialize_providers_dataset_uri_handlers_and_factories(self):
        """Lazy initialization of provider dataset URI handlers."""
        self.initialize_providers_list()
@@ -886,23 +886,23 @@ def _discover_dataset_uri_handlers_and_factories(self) -> None:

        for provider_package, provider in self._provider_dict.items():
            for handler_info in provider.data.get("dataset-uris", []):
-                try:
-                    schemes = handler_info["schemes"]
-                    handler_path = handler_info["handler"]
-                except KeyError:
+                schemes = handler_info.get("schemes")
+                handler_path = handler_info.get("handler")
+                factory_path = handler_info.get("factory")
+                if schemes is None:
                    continue
-                if handler_path is None:
+
+                if handler_path is not None and (
+                    handler := _correctness_check(provider_package, handler_path, provider)
+                ):
+                    pass
+                else:
                    handler = normalize_noop
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
-                    continue
                self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                factory_path = handler_info.get("factory")
-                if not (
-                    factory_path is not None
-                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                if factory_path is not None and (
+                    factory := _correctness_check(provider_package, factory_path, provider)
                ):
-                    continue
-                self._dataset_factories.update((scheme, factory) for scheme in schemes)
+                    self._dataset_factories.update((scheme, factory) for scheme in schemes)

    def _discover_taskflow_decorators(self) -> None:
        for name, info in self._provider_dict.items():
@@ -1302,6 +1302,7 @@ def filesystem_module_names(self) -> list[str]:
    @property
    def dataset_factories(self) -> dict[str, Callable[..., Dataset]]:
        self.initialize_providers_dataset_uri_handlers_and_factories()
+        self.log.error(self._dataset_factories)
        return self._dataset_factories

    @property
8 changes: 4 additions & 4 deletions dev/breeze/tests/test_selective_checks.py
@@ -519,7 +519,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
            ("airflow/providers/amazon/__init__.py",),
            {
                "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                "mongo mysql openlineage postgres salesforce ssh teradata",
                "all-python-versions": "['3.8']",
                "all-python-versions-list-as-string": "3.8",
@@ -535,7 +535,7 @@
                "upgrade-to-newer-dependencies": "false",
                "run-amazon-tests": "true",
                "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                "needs-mypy": "true",
                "mypy-folders": "['providers']",
@@ -569,7 +569,7 @@
            ("airflow/providers/amazon/file.py",),
            {
                "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                "mongo mysql openlineage postgres salesforce ssh teradata",
                "all-python-versions": "['3.8']",
                "all-python-versions-list-as-string": "3.8",
@@ -585,7 +585,7 @@
                "run-kubernetes-tests": "false",
                "upgrade-to-newer-dependencies": "false",
                "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                "needs-mypy": "true",
                "mypy-folders": "['providers']",
1 change: 1 addition & 0 deletions generated/provider_dependencies.json
@@ -57,6 +57,7 @@
    "cross-providers-deps": [
      "apache.hive",
      "cncf.kubernetes",
+      "common.compat",
      "common.sql",
      "exasol",
      "ftp",
10 changes: 10 additions & 0 deletions tests/conftest.py
@@ -1326,6 +1326,16 @@ def airflow_root_path() -> Path:
    return Path(airflow.__path__[0]).parent


+@pytest.fixture
+def hook_lineage_collector():
+    from airflow.lineage import hook
+
+    hook._hook_lineage_collector = None
+    hook._hook_lineage_collector = hook.HookLineageCollector()
+    yield hook.get_hook_lineage_collector()
+    hook._hook_lineage_collector = None
+
+
# This constant is set to True if tests are run with Airflow installed from Packages rather than running
# the tests within Airflow sources. While most tests in CI are run using Airflow sources, there are
# also compatibility tests that only use `tests` package and run against installed packages of Airflow in
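For context, a hedged sketch of how a test could use this fixture to check that a hook registers lineage; mocking get_conn and the collected_datasets read accessor are assumptions made for illustration and are not shown in this diff.

from unittest import mock

from airflow.providers.amazon.aws.hooks.s3 import S3Hook


def test_copy_object_collects_lineage(hook_lineage_collector):
    hook = S3Hook()
    with mock.patch.object(S3Hook, "get_conn"):  # avoid real AWS calls
        hook.copy_object(
            source_bucket_key="src-key",
            dest_bucket_key="dst-key",
            source_bucket_name="src-bucket",
            dest_bucket_name="dst-bucket",
        )
    # `collected_datasets` is an assumed read API on HookLineageCollector;
    # adjust the assertion to whatever the collector actually exposes.
    lineage = hook_lineage_collector.collected_datasets
    assert len(lineage.inputs) == 1
    assert len(lineage.outputs) == 1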
16 changes: 16 additions & 0 deletions tests/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
27 changes: 27 additions & 0 deletions tests/providers/amazon/aws/datasets/test_s3.py
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.datasets import Dataset
from airflow.providers.amazon.aws.datasets.s3 import create_dataset


def test_create_dataset():
    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path")
    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
        uri="s3://test-bucket/test-dir/test-path"
    )