fix: add correct support for compressing file-like objects #174

Merged
9 commits, merged on Oct 4, 2023
6 changes: 2 additions & 4 deletions ibm_cloud_sdk_core/base_service.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import gzip
 import json as json_import
 import logging
 import platform
@@ -39,6 +38,7 @@
     read_external_sources,
     strip_extra_slashes,
     SSLHTTPAdapter,
+    GzipStream,
 )
 from .version import __version__

@@ -420,10 +420,8 @@ def prepare_request(
         # Compress the request body if applicable
         if self.get_enable_gzip_compression() and 'content-encoding' not in headers and request['data'] is not None:
             headers['content-encoding'] = 'gzip'
-            uncompressed_data = request['data']
-            request_body = gzip.compress(uncompressed_data)
-            request['data'] = request_body
             request['headers'] = headers
+            request['data'] = GzipStream(request['data'])

         # Next, we need to process the 'files' argument to try to fill in
         # any missing filenames where possible.
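
The practical effect of this change: instead of eagerly compressing the whole body with gzip.compress, prepare_request now hands requests a file-like GzipStream that compresses lazily as the HTTP layer reads it. A minimal sketch of the difference, assuming GzipStream is imported from ibm_cloud_sdk_core.utils, where this PR defines it:

import gzip
import io

from ibm_cloud_sdk_core.utils import GzipStream

payload = b'{"foo": "bar"}'

# Before: the raw body and its compressed copy both sit in memory at once.
eager = gzip.compress(payload)

# After: GzipStream wraps the source and compresses on demand;
# requests/urllib3 pull compressed chunks via read().
stream = GzipStream(io.BytesIO(payload))
first = stream.read(10)   # drains only what is buffered so far
rest = stream.read()      # compresses and drains the remainder
assert gzip.decompress(first + rest) == payload == gzip.decompress(eager)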
87 changes: 87 additions & 0 deletions ibm_cloud_sdk_core/utils.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 # from ibm_cloud_sdk_core.authenticators import Authenticator
 import datetime
+import gzip
+import io
 import json as json_import
 import re
 import ssl
@@ -43,6 +45,91 @@ def init_poolmanager(self, connections, maxsize, block):
         super().init_poolmanager(connections, maxsize, block, ssl_context=ssl_context)


+class GzipStream(io.RawIOBase):
+    """Compress files on the fly.
+
+    GzipStream is a helper class around the gzip library. It helps to
+    compress already opened files (file-like objects) on the fly, so
+    there is no need to read everything into memory and call the
+    `compress` function on it.
+    The GzipFile is opened on the instance itself, so it needs to act
+    as a file-like object.
+
+    Args:
+        source: the source of the data to be compressed.
+            It can be a file-like object, bytes or a string.
+    """
+
+    def __init__(self, source: Union[io.IOBase, bytes, str]) -> None:
+        self.buffer = b''
+
+        if isinstance(source, io.IOBase):
+            # The input is already a file-like object, use it as-is.
+            self.uncompressed = source
+        elif isinstance(source, str):
+            # Strings must be handled with StringIO.
+            self.uncompressed = io.StringIO(source)
+        else:
+            # Handle the rest as raw bytes.
+            self.uncompressed = io.BytesIO(source)
+
+        self.compressor = gzip.GzipFile(fileobj=self, mode='wb')
+
+    def read(self, size: int = -1) -> bytes:
+        """Compresses and returns the requested size of data.
+
+        Args:
+            size: how many bytes to return. -1 to read and compress the whole file.
+        """
+        compressed = b''
+
+        if (size < 0) or (len(self.buffer) < size):
+            for raw in self.uncompressed:
+                # We need to encode text-like streams (e.g. TextIOWrapper) to bytes.
+                if isinstance(raw, str):
+                    raw = raw.encode()
+
+                self.compressor.write(raw)
+
+                # Stop compressing if we have reached the max allowed size.
+                if 0 < size < len(self.buffer):
+                    self.compressor.flush()
+                    break
+            else:
+                self.compressor.close()
+
+        if size < 0:
+            # Return all data from the buffer.
+            compressed = self.buffer
+            self.buffer = b''
+        else:
+            # If we already have enough data in our buffer,
+            # return the desired chunk of bytes
+            compressed = self.buffer[:size]
+            # then remove them from the buffer.
+            self.buffer = self.buffer[size:]
+
+        return compressed
+
+    def flush(self) -> None:
+        """Not implemented."""
+        # Since this "pipe" sits between two other streams (source/read -> target/write),
+        # it is not worth implementing flushing.
+        pass
+
+    def write(self, compressed: bytes) -> None:
+        """Append the compressed data to the buffer.
+
+        This happens when the target stream calls the `read` method and
+        that triggers the gzip "compressor".
+        """
+        self.buffer += compressed
+
+    def close(self) -> None:
+        """Closes the underlying file-like object."""
+        self.uncompressed.close()
+
+
 def has_bad_first_or_last_char(val: str) -> bool:
     """Returns true if a string starts with any of: {," ; or ends with any of: },".

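
To make the control flow of GzipStream easier to follow: passing fileobj=self to GzipFile means the compressor's output is delivered to GzipStream.write, which appends it to self.buffer, and read then drains that buffer. The same pattern in isolation, as an illustrative sketch (CapturingSink is a made-up name, not part of this PR):

import gzip
import io


class CapturingSink(io.RawIOBase):
    """Illustrative stand-in: collects whatever GzipFile emits."""

    def __init__(self) -> None:
        self.buffer = b''

    def write(self, data: bytes) -> None:
        # GzipFile calls this with the compressed chunks (header,
        # deflate blocks, trailer) as they are produced.
        self.buffer += data


sink = CapturingSink()
compressor = gzip.GzipFile(fileobj=sink, mode='wb')
compressor.write(b'hello world')
compressor.close()  # flushes remaining data and the gzip trailer into the sink

assert gzip.decompress(sink.buffer) == b'hello world'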
70 changes: 65 additions & 5 deletions test/test_base_service.py
@@ -607,13 +607,13 @@ def test_gzip_compression():
     service.set_enable_gzip_compression(True)
     assert service.get_enable_gzip_compression()
     prepped = service.prepare_request('GET', url='', data=json.dumps({"foo": "bar"}))
-    assert prepped['data'] == gzip.compress(b'{"foo": "bar"}')
+    assert prepped['data'].read() == gzip.compress(b'{"foo": "bar"}')
     assert prepped['headers'].get('content-encoding') == 'gzip'

     # Should return compressed data when gzip is on for non-json data
     assert service.get_enable_gzip_compression()
     prepped = service.prepare_request('GET', url='', data=b'rawdata')
-    assert prepped['data'] == gzip.compress(b'rawdata')
+    assert prepped['data'].read() == gzip.compress(b'rawdata')
     assert prepped['headers'].get('content-encoding') == 'gzip'

     # Should return compressed data when gzip is on for gzip file data
@@ -624,7 +624,7 @@ def test_gzip_compression():
     with gzip.GzipFile(mode='rb', fileobj=t_f) as gz_f:
         gzip_data = gz_f.read()
     prepped = service.prepare_request('GET', url='', data=gzip_data)
-    assert prepped['data'] == gzip.compress(t_f.read())
+    assert prepped['data'].read() == gzip.compress(t_f.read())
     assert prepped['headers'].get('content-encoding') == 'gzip'

     # Should return compressed json data when gzip is on for gzip file json data
@@ -635,7 +635,7 @@ def test_gzip_compression():
     with gzip.GzipFile(mode='rb', fileobj=t_f) as gz_f:
         gzip_data = gz_f.read()
     prepped = service.prepare_request('GET', url='', data=gzip_data)
-    assert prepped['data'] == gzip.compress(t_f.read())
+    assert prepped['data'].read() == gzip.compress(t_f.read())
     assert prepped['headers'].get('content-encoding') == 'gzip'

     # Should return uncompressed data when content-encoding is set
@@ -647,6 +647,66 @@ def test_gzip_compression():
     assert prepped['headers'].get('content-encoding') == 'gzip'


+def test_gzip_compression_file_input():
(Inline review comment from the PR author: "Had to put the new test cases into a separate function to avoid the too-many-branches linter error.")
+    service = AnyServiceV1('2018-11-20', authenticator=NoAuthAuthenticator())
+    service.set_enable_gzip_compression(True)
+
+    # Should return file-like object with the compressed data when compression is on
+    # and the input is a file, opened for reading in binary mode.
+    raw_data = b'rawdata'
+    with tempfile.TemporaryFile(mode='w+b') as tmp_file:
+        tmp_file.write(raw_data)
+        tmp_file.seek(0)
+
+        prepped = service.prepare_request('GET', url='', data=tmp_file)
+        assert prepped['data'].read() == gzip.compress(raw_data)
+        assert prepped['headers'].get('content-encoding') == 'gzip'
+        assert prepped['data'].read() == b''
+
+    # Simulate the requests (urllib3) package reading method for binary files.
+    with tempfile.TemporaryFile(mode='w+b') as tmp_file:
+        tmp_file.write(raw_data)
+        tmp_file.seek(0)
+
+        prepped = service.prepare_request('GET', url='', data=tmp_file)
+        compressed = b''
+        for chunk in prepped['data']:
+            compressed += chunk
+
+        assert compressed == gzip.compress(raw_data)
+
+        # Make sure the decompression works fine.
+        assert gzip.decompress(compressed) == raw_data
+
+    # Should return file-like object with the compressed data when compression is on
+    # and the input is a file, opened for reading in text mode.
+    assert service.get_enable_gzip_compression()
+    text_data = 'textdata'
+    with tempfile.TemporaryFile(mode='w+') as tmp_file:
+        tmp_file.write(text_data)
+        tmp_file.seek(0)
+
+        prepped = service.prepare_request('GET', url='', data=tmp_file)
+        assert prepped['data'].read() == gzip.compress(text_data.encode())
+        assert prepped['headers'].get('content-encoding') == 'gzip'
+        assert prepped['data'].read() == b''
+
+    # Simulate the requests (urllib3) package reading method for text files.
+    with tempfile.TemporaryFile(mode='w+') as tmp_file:
+        tmp_file.write(text_data)
+        tmp_file.seek(0)
+
+        prepped = service.prepare_request('GET', url='', data=tmp_file)
+        compressed = b''
+        for chunk in prepped['data']:
+            compressed += chunk
+
+        assert compressed == gzip.compress(text_data.encode())
+
+        # Make sure the decompression works fine.
+        assert gzip.decompress(compressed).decode() == text_data
+
+
 def test_gzip_compression_external():
     # Should set gzip compression from external config
     file_path = os.path.join(os.path.dirname(__file__), '../resources/ibm-credentials-gzip.env')
@@ -655,7 +715,7 @@ def test_gzip_compression_external():
     assert service.service_url == 'https://mockurl'
     assert service.get_enable_gzip_compression() is True
     prepped = service.prepare_request('GET', url='', data=json.dumps({"foo": "bar"}))
-    assert prepped['data'] == gzip.compress(b'{"foo": "bar"}')
+    assert prepped['data'].read() == gzip.compress(b'{"foo": "bar"}')
     assert prepped['headers'].get('content-encoding') == 'gzip'
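
For reference, the external-config test above reads its settings from an env-style credentials file. A hypothetical sketch of what such a file could contain (the SERVICE_1 prefix and the exact key names are assumptions for illustration; only the URL value is confirmed by the assertions above):

# Hypothetical ibm-credentials-gzip.env contents (assumed key names).
SERVICE_1_AUTH_TYPE=noAuth
SERVICE_1_URL=https://mockurl
SERVICE_1_ENABLE_GZIP=true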