Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce s3hook memory usage #37886

Merged
merged 3 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 23 additions & 30 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
from functools import wraps
from functools import cached_property, wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlsplit
from uuid import uuid4

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

from airflow.utils.types import ArgNotSet

with suppress(ImportError):
Expand All @@ -55,22 +57,17 @@
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def provide_bucket_name(func: T) -> T:
def provide_bucket_name(func: Callable) -> Callable:
Taragolis marked this conversation as resolved.
Show resolved Hide resolved
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
Expand All @@ -90,10 +87,10 @@ def wrapper(*args, **kwargs) -> T:

return func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def provide_bucket_name_async(func: T) -> T:
def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)

Expand All @@ -110,15 +107,15 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

return await func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def unify_bucket_name_and_key(func: T) -> T:
def unify_bucket_name_and_key(func: Callable) -> Callable:
"""Unify bucket name and key in case no bucket name and at least a key has been passed to the function."""
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "wildcard_key" in bound_args.arguments:
Expand All @@ -141,7 +138,7 @@ def wrapper(*args, **kwargs) -> T:
# if provide_bucket_name is applied first, and there's a bucket defined in conn
# then if user supplies full key, bucket in key is not respected
wrapper._unify_bucket_name_and_key_wrapped = True # type: ignore[attr-defined]
return cast(T, wrapper)
return wrapper


class S3Hook(AwsBaseHook):
Expand Down Expand Up @@ -188,6 +185,15 @@ def __init__(

super().__init__(*args, **kwargs)

@cached_property
def resource(self):
return self.get_session().resource(
self.service_name,
endpoint_url=self.conn_config.get_service_endpoint_url(service_name=self.service_name),
config=self.config,
verify=self.verify,
)

@property
def extra_args(self):
"""Return hook's extra arguments (immutable)."""
Expand Down Expand Up @@ -307,13 +313,7 @@ def get_bucket(self, bucket_name: str | None = None) -> S3Bucket:
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
"""
s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
return s3_resource.Bucket(bucket_name)
return self.resource.Bucket(bucket_name)

@provide_bucket_name
def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
Expand Down Expand Up @@ -943,14 +943,7 @@ def sanitize_extra_args() -> dict[str, str]:
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}

s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
obj = s3_resource.Object(bucket_name, key)

obj = self.resource.Object(bucket_name, key)
obj.load(**sanitize_extra_args())
return obj

Expand Down
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def test_get_conn(self):
hook = S3Hook()
assert hook.get_conn() is not None

def test_resource(self):
hook = S3Hook()
assert hook.resource is not None

def test_use_threads_default_value(self):
hook = S3Hook()
assert hook.transfer_config.use_threads is True
Expand Down