Skip to content

Commit

Permalink
Add requests/urllib3 work-around for intercepting gzipped bytes. (#36)
Browse files Browse the repository at this point in the history
Fixes #34.
  • Loading branch information
dhermes authored Oct 20, 2017
1 parent 47b4d65 commit b6c62d8
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 10 deletions.
62 changes: 61 additions & 1 deletion google/resumable_media/requests/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import hashlib
import logging

import urllib3.response

from google.resumable_media import _download
from google.resumable_media import common
from google.resumable_media.requests import _helpers
Expand Down Expand Up @@ -113,11 +115,14 @@ def _write_to_stream(self, response):
else:
md5_hash = hashlib.md5()
with response:
# NOTE: This might "donate" ``md5_hash`` to the decoder and replace
# it with a ``_DoNothingHash``.
local_hash = _add_decoder(response.raw, md5_hash)
body_iter = response.iter_content(
chunk_size=_SINGLE_GET_CHUNK_SIZE, decode_unicode=False)
for chunk in body_iter:
self._stream.write(chunk)
md5_hash.update(chunk)
local_hash.update(chunk)

if expected_md5_hash is None:
return
Expand Down Expand Up @@ -286,3 +291,58 @@ def update(self, unused_chunk):
Args:
unused_chunk (bytes): A chunk of data.
"""


def _add_decoder(response_raw, md5_hash):
    """Install a hash-aware ``gzip`` decoder on a ``urllib3`` response.

    Replacing ``_decoder`` lets the hash observe the compressed bytes
    before they are decoded. Nothing is replaced unless the response's
    content encoding is ``gzip``.

    Args:
        response_raw (urllib3.response.HTTPResponse): The raw response for
            an HTTP request.
        md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
            will get updated when it encounters compressed bytes.

    Returns:
        Union[_DoNothingHash, hashlib.md5]: ``md5_hash`` itself when
        ``_decoder`` is left alone. Otherwise a new ``_DoNothingHash``,
        because the decoder now holds ``md5_hash`` and the caller no longer
        needs to hash the decoded bytes.
    """
    content_encoding = response_raw.headers.get(u'content-encoding', u'')
    is_gzipped = content_encoding.lower() == u'gzip'
    if is_gzipped:
        # Hand ``md5_hash`` to the decoder; the caller gets a no-op hash.
        response_raw._decoder = _GzipDecoder(md5_hash)
        return _DoNothingHash()

    return md5_hash


class _GzipDecoder(urllib3.response.GzipDecoder):
    """``urllib3`` ``gzip`` decoder that also hashes its input.

    Every chunk of compressed bytes is fed to an MD5 hash function before
    being decoded, so the hash of the compressed payload can be computed.

    Args:
        md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
            will get updated when it encounters compressed bytes.
    """

    def __init__(self, md5_hash):
        super(_GzipDecoder, self).__init__()
        self._md5_hash = md5_hash

    def decompress(self, data):
        """Hash the compressed bytes, then decode them.

        Args:
            data (bytes): The compressed bytes to be decompressed.

        Returns:
            bytes: The decompressed bytes from ``data``.
        """
        # The hash must see the bytes exactly as received, so update it
        # before delegating to the parent decoder.
        hasher = self._md5_hash
        hasher.update(data)
        decoded = super(_GzipDecoder, self).decompress(data)
        return decoded
15 changes: 11 additions & 4 deletions nox.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
'GOOGLE_RESUMABLE_MEDIA_BUCKET',
'GOOGLE_APPLICATION_CREDENTIALS',
)
REQUESTS = 'requests >= 2.18.0, < 3.0.0dev'
GOOGLE_AUTH = 'google-auth >= 0.10.0'


@nox.session
Expand All @@ -33,7 +35,7 @@ def unit_tests(session, python_version):
session.interpreter = 'python{}'.format(python_version)

# Install all test dependencies, then install this package in-place.
session.install('mock', 'pytest', 'pytest-cov')
session.install('mock', 'pytest', 'pytest-cov', REQUESTS)
session.install('-e', '.')

# Run py.test against the unit tests.
Expand Down Expand Up @@ -63,7 +65,11 @@ def docs(session):
# Install Sphinx and other dependencies.
session.chdir(os.path.realpath(os.path.dirname(__file__)))
session.install(
'sphinx', 'sphinx_rtd_theme', 'sphinx-docstring-typing >= 0.0.3')
'sphinx',
'sphinx_rtd_theme',
'sphinx-docstring-typing >= 0.0.3',
REQUESTS,
)
session.install('-e', '.')

# Build the docs!
Expand All @@ -82,7 +88,8 @@ def doctest(session):
'sphinx_rtd_theme',
'sphinx-docstring-typing >= 0.0.3',
'mock',
'google-auth'
GOOGLE_AUTH,
REQUESTS,
)
session.install('-e', '.')

Expand Down Expand Up @@ -142,7 +149,7 @@ def system_tests(session, python_version):

# Install all test dependencies, then install this package into the
# virtualenv's dist-packages.
session.install('mock', 'pytest', 'requests', 'google-auth >= 0.10.0')
session.install('mock', 'pytest', REQUESTS, GOOGLE_AUTH)
session.install('-e', '.')

# Run py.test against the system tests.
Expand Down
5 changes: 2 additions & 3 deletions tests/system/requests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@
u'path': os.path.realpath(os.path.join(DATA_DIR, u'file.txt.gz')),
u'uncompressed': os.path.realpath(os.path.join(DATA_DIR, u'file.txt')),
u'content_type': PLAIN_TEXT,
# NOTE: This **should** be u'KHRs/+ZSrc/FuuR4qz/PZQ=='.
u'checksum': u'XHSHAr/SpIeZtZbjgQ4nGw==',
u'checksum': u'KHRs/+ZSrc/FuuR4qz/PZQ==',
u'slices': (),
u'metadata': {
u'contentEncoding': u'gzip',
Expand Down Expand Up @@ -96,7 +95,7 @@ class CorruptingAuthorizedSession(tr_requests.AuthorizedSession):
def request(self, method, url, data=None, headers=None, **kwargs):
    """Implementation of Requests' request."""
    # Delegate to the real session, forwarding every extra keyword
    # argument untouched.
    parent_request = tr_requests.AuthorizedSession.request
    response = parent_request(
        self, method, url, data=data, headers=headers, **kwargs)
    # Overwrite the hash header with the value built from ``EMPTY_HASH``.
    corrupted_value = u'md5={}'.format(self.EMPTY_HASH)
    response.headers[download_mod._HASH_HEADER] = corrupted_value
    return response
Expand Down
53 changes: 51 additions & 2 deletions tests/unit/requests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,15 @@ def _mock_response(self, start_byte, end_byte, total_bytes,
response_headers = self._response_headers(
start_byte, end_byte, total_bytes)
return mock.Mock(
content=content, headers=response_headers, status_code=status_code,
spec=[u'content', u'headers', u'status_code'])
content=content,
headers=response_headers,
status_code=status_code,
spec=[
u'content',
u'headers',
u'status_code',
],
)

def test_consume_next_chunk_already_finished(self):
download = download_mod.ChunkedDownload(EXAMPLE_URL, 512, None)
Expand Down Expand Up @@ -355,20 +362,62 @@ def test__DoNothingHash():
assert return_value is None


class Test__add_decoder(object):
    """Unit tests for ``download_mod._add_decoder``."""

    def test_non_gzipped(self):
        # No ``content-encoding`` header: the hash is handed back as-is.
        raw = mock.Mock(headers={}, spec=[u'headers'])
        original_hash = mock.sentinel.md5_hash

        returned = download_mod._add_decoder(raw, original_hash)

        assert returned is original_hash

    def test_gzipped(self):
        # ``gzip`` encoding: the decoder is swapped in and takes the hash.
        raw = mock.Mock(
            headers={u'content-encoding': u'gzip'},
            spec=[u'headers', u'_decoder'],
        )
        original_hash = mock.sentinel.md5_hash

        returned = download_mod._add_decoder(raw, original_hash)

        assert returned is not original_hash
        assert isinstance(returned, download_mod._DoNothingHash)
        installed = raw._decoder
        assert isinstance(installed, download_mod._GzipDecoder)
        assert installed._md5_hash is original_hash


class Test_GzipDecoder(object):
    """Unit tests for ``download_mod._GzipDecoder``."""

    def test_constructor(self):
        original_hash = mock.sentinel.md5_hash
        gzip_decoder = download_mod._GzipDecoder(original_hash)
        assert gzip_decoder._md5_hash is original_hash

    def test_decompress(self):
        hasher = mock.Mock(spec=['update'])
        gzip_decoder = download_mod._GzipDecoder(hasher)

        # Four bytes that decode to nothing, yet must still reach the hash.
        payload = b'\x1f\x8b\x08\x08'
        output = gzip_decoder.decompress(payload)

        assert output == b''
        hasher.update.assert_called_once_with(payload)


def _mock_response(status_code=http_client.OK, chunks=(), headers=None):
if headers is None:
headers = {}

if chunks:
mock_raw = mock.Mock(headers=headers, spec=[u'headers'])
response = mock.MagicMock(
headers=headers,
status_code=int(status_code),
raw=mock_raw,
spec=[
u'__enter__',
u'__exit__',
u'iter_content',
u'status_code',
u'headers',
u'raw',
],
)
# i.e. context manager returns ``self``.
Expand Down

0 comments on commit b6c62d8

Please sign in to comment.