Skip to content

Commit

Permalink
Merge pull request #2 from ellisms/s3hook2
Browse files Browse the repository at this point in the history
Cache s3 resource to reduce memory usage
  • Loading branch information
ellisms authored Mar 4, 2024
2 parents 30f7b2a + bc9c9c0 commit 95c12e6
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
63 changes: 29 additions & 34 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
from functools import wraps
from functools import cached_property, wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlsplit
from uuid import uuid4

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

from airflow.utils.types import ArgNotSet

with suppress(ImportError):
Expand All @@ -55,22 +57,17 @@
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject

T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def provide_bucket_name(func: T) -> T:
def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "bucket_name" not in bound_args.arguments:
Expand All @@ -90,10 +87,10 @@ def wrapper(*args, **kwargs) -> T:

return func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def provide_bucket_name_async(func: T) -> T:
def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)

Expand All @@ -110,15 +107,15 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

return await func(*bound_args.args, **bound_args.kwargs)

return cast(T, wrapper)
return wrapper


def unify_bucket_name_and_key(func: T) -> T:
def unify_bucket_name_and_key(func: Callable) -> Callable:
"""Unify bucket name and key in case no bucket name and at least a key has been passed to the function."""
function_signature = signature(func)

@wraps(func)
def wrapper(*args, **kwargs) -> T:
def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)

if "wildcard_key" in bound_args.arguments:
Expand All @@ -141,7 +138,7 @@ def wrapper(*args, **kwargs) -> T:
# if provide_bucket_name is applied first, and there's a bucket defined in conn
# then if user supplies full key, bucket in key is not respected
wrapper._unify_bucket_name_and_key_wrapped = True # type: ignore[attr-defined]
return cast(T, wrapper)
return wrapper


class S3Hook(AwsBaseHook):
Expand Down Expand Up @@ -188,6 +185,15 @@ def __init__(

super().__init__(*args, **kwargs)

@cached_property
def resource(self):
    """Boto3 S3 service resource for this hook.

    Built once per hook instance (``cached_property``) so repeated
    calls reuse the same resource object instead of constructing a
    new one each time.
    """
    # Reuse the hook's session/connection configuration so the resource
    # targets the same endpoint and TLS settings as the client.
    session = self.get_session()
    return session.resource(
        "s3",
        endpoint_url=self.conn_config.endpoint_url,
        config=self.config,
        verify=self.verify,
    )

@property
def extra_args(self):
"""Return hook's extra arguments (immutable)."""
Expand Down Expand Up @@ -307,13 +313,7 @@ def get_bucket(self, bucket_name: str | None = None) -> S3Bucket:
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
"""
s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
return s3_resource.Bucket(bucket_name)
return self.resource.Bucket(bucket_name)

@provide_bucket_name
def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
Expand Down Expand Up @@ -943,14 +943,7 @@ def sanitize_extra_args() -> dict[str, str]:
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}

s3_resource = self.get_session().resource(
"s3",
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
obj = s3_resource.Object(bucket_name, key)

obj = self.resource.Object(bucket_name, key)
obj.load(**sanitize_extra_args())
return obj

Expand Down Expand Up @@ -1369,10 +1362,6 @@ def download_file(
"""
Download a file from the S3 location to the local file system.
Note:
This function shadows the 'download_file' method of S3 API, but it is not the same.
If you want to use the original method from S3 API, please use 'S3Hook.get_conn().download_file()'
.. seealso::
- :external+boto3:py:meth:`S3.Object.download_fileobj`
Expand All @@ -1390,6 +1379,12 @@ def download_file(
Default: True.
:return: the file name.
"""
self.log.info(
"This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
"want to use the original method from S3 API, please call "
"'S3Hook.get_conn().download_file()'"
)

self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)

try:
Expand Down
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def test_get_conn(self):
hook = S3Hook()
assert hook.get_conn() is not None

def test_resource(self):
    # A default-configured hook should expose a usable cached S3 resource.
    s3_resource = S3Hook().resource
    assert s3_resource is not None

def test_use_threads_default_value(self):
hook = S3Hook()
assert hook.transfer_config.use_threads is True
Expand Down

0 comments on commit 95c12e6

Please sign in to comment.