From 89b6207767ecf287a07f44686ca30cc1d6e58aaa Mon Sep 17 00:00:00 2001 From: Omair Khan Date: Tue, 10 Mar 2020 21:18:43 +0530 Subject: [PATCH] [AIRFLOW-4438] Add Gzip compression to S3_hook - Added bool parameter gzip to load_file to s3_hook - Tested the load_file with load_file_gzip unittest - Updated the load_file docstring to reflect the extra parameter --- airflow/providers/amazon/aws/hooks/s3.py | 13 ++++++++++++- tests/providers/amazon/aws/hooks/test_s3.py | 11 +++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index 7bf62743775f68..b0b4d179479144 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -21,8 +21,10 @@ Interact with AWS S3, using the boto3 library. """ import fnmatch +import gzip as gz import io import re +import shutil from functools import wraps from inspect import signature from tempfile import NamedTemporaryFile @@ -425,7 +427,8 @@ def load_file(self, key, bucket_name=None, replace=False, - encrypt=False): + encrypt=False, + gzip=False): """ Loads a local file to S3 @@ -442,6 +445,8 @@ def load_file(self, :param encrypt: If True, the file will be encrypted on the server-side by S3 and will be stored in an encrypted form while at rest in S3. :type encrypt: bool + :param gzip: If True, the file will be compressed locally + :type gzip: bool """ if not replace and self.check_for_key(key, bucket_name): @@ -450,6 +455,12 @@ def load_file(self, extra_args = {} if encrypt: extra_args['ServerSideEncryption'] = "AES256" + if gzip: + filename_gz = filename.name + '.gz' + with open(filename.name, 'rb') as f_in: + with gz.open(filename_gz, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + filename = filename_gz client = self.get_conn() client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args) diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py index a1a7cacbdd532e..f2d4c8fd284136 100644 --- a/tests/providers/amazon/aws/hooks/test_s3.py +++ b/tests/providers/amazon/aws/hooks/test_s3.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. # + +import gzip as gz import tempfile from unittest.mock import Mock @@ -244,6 +246,15 @@ def test_load_fileobj(self, s3_bucket): resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member assert resource.get()['Body'].read() == b'Content' + def test_load_file_gzip(self, s3_bucket): + hook = S3Hook() + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(b"Content") + temp_file.seek(0) + hook.load_file(temp_file, "my_key", s3_bucket, gzip=True) + resource = boto3.resource('s3').Object(s3_bucket, 'my_key') # pylint: disable=no-member + assert gz.decompress(resource.get()['Body'].read()) == b'Content' + @mock.patch.object(S3Hook, 'get_connection', return_value=Connection(schema='test_bucket')) def test_provide_bucket_name(self, mock_get_connection):