Skip to content

Commit

Permalink
Always use raw response data. (#87)
Browse files Browse the repository at this point in the history
Don't expect 'Content-Length' header for chunked downloads.

Closes #37.
Closes #49.
Closes #50.
Closes #56.
Closes #76.
  • Loading branch information
tseaver authored Aug 27, 2019
1 parent 9fdaab3 commit 2b9ffc8
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 188 deletions.
34 changes: 23 additions & 11 deletions google/resumable_media/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,18 +343,30 @@ def _process_response(self, response):
_helpers.require_status_code(
response, _ACCEPTABLE_STATUS_CODES,
self._get_status_code, callback=self._make_invalid)
content_length = _helpers.header_required(
response, u'content-length', self._get_headers,
callback=self._make_invalid)
num_bytes = int(content_length)
_, end_byte, total_bytes = get_range_info(
response, self._get_headers, callback=self._make_invalid)
headers = self._get_headers(response)
response_body = self._get_body(response)
if len(response_body) != num_bytes:
self._make_invalid()
raise common.InvalidResponse(
response, u'Response is different size than content-length',
u'Expected', num_bytes, u'Received', len(response_body))

start_byte, end_byte, total_bytes = get_range_info(
response, self._get_headers, callback=self._make_invalid)

transfer_encoding = headers.get(u'transfer-encoding')

if transfer_encoding is None:
content_length = _helpers.header_required(
response, u'content-length', self._get_headers,
callback=self._make_invalid)
num_bytes = int(content_length)
if len(response_body) != num_bytes:
self._make_invalid()
raise common.InvalidResponse(
response,
u'Response is different size than content-length',
u'Expected', num_bytes,
u'Received', len(response_body),
)
else:
# 'content-length' header not allowed with chunked encoding.
num_bytes = end_byte - start_byte + 1

# First update ``bytes_downloaded``.
self._bytes_downloaded += num_bytes
Expand Down
8 changes: 7 additions & 1 deletion google/resumable_media/requests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


_DEFAULT_RETRY_STRATEGY = common.RetryStrategy()
_SINGLE_GET_CHUNK_SIZE = 8192
# The number of seconds to wait to establish a connection
# (connect() call on socket). Avoid setting this to a multiple of 3 to not
# Align with TCP Retransmission timing. (typically 2.5-3s)
Expand Down Expand Up @@ -75,7 +76,12 @@ def _get_body(response):
Returns:
bytes: The body of the ``response``.
"""
return response.content
if response._content is False:
response._content = b''.join(
response.raw.stream(
_SINGLE_GET_CHUNK_SIZE, decode_content=False))
response._content_consumed = True
return response._content


def http_request(transport, method, url, data=None, headers=None,
Expand Down
85 changes: 13 additions & 72 deletions google/resumable_media/requests/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,12 @@
import hashlib
import logging

import urllib3.response

from google.resumable_media import _download
from google.resumable_media import common
from google.resumable_media.requests import _helpers


_LOGGER = logging.getLogger(__name__)
_SINGLE_GET_CHUNK_SIZE = 8192
_HASH_HEADER = u'x-goog-hash'
_MISSING_MD5 = u"""\
No MD5 checksum was returned from the service while downloading {}
Expand Down Expand Up @@ -117,12 +114,12 @@ def _write_to_stream(self, response):
with response:
# NOTE: This might "donate" ``md5_hash`` to the decoder and replace
# it with a ``_DoNothingHash``.
local_hash = _add_decoder(response.raw, md5_hash)
body_iter = response.iter_content(
chunk_size=_SINGLE_GET_CHUNK_SIZE, decode_unicode=False)
body_iter = response.raw.stream(
_helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False)
for chunk in body_iter:
self._stream.write(chunk)
local_hash.update(chunk)
md5_hash.update(chunk)
response._content_consumed = True

if expected_md5_hash is None:
return
Expand Down Expand Up @@ -157,16 +154,15 @@ def consume(self, transport):
"""
method, url, payload, headers = self._prepare_request()
# NOTE: We assume "payload is None" but pass it along anyway.
request_kwargs = {
u'data': payload,
u'headers': headers,
u'retry_strategy': self._retry_strategy,
}
if self._stream is not None:
request_kwargs[u'stream'] = True

result = _helpers.http_request(
transport, method, url, **request_kwargs)
transport,
method,
url,
data=payload,
headers=headers,
retry_strategy=self._retry_strategy,
stream=True,
)

self._process_response(result)

Expand Down Expand Up @@ -221,7 +217,7 @@ def consume_next_chunk(self, transport):
# NOTE: We assume "payload is None" but pass it along anyway.
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, stream=True)
self._process_response(result)
return result

Expand Down Expand Up @@ -291,58 +287,3 @@ def update(self, unused_chunk):
Args:
unused_chunk (bytes): A chunk of data.
"""


def _add_decoder(response_raw, md5_hash):
"""Patch the ``_decoder`` on a ``urllib3`` response.
This is so that we can intercept the compressed bytes before they are
decoded.
Only patches if the content encoding is ``gzip``.
Args:
response_raw (urllib3.response.HTTPResponse): The raw response for
an HTTP request.
md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
will get updated when it encounters compressed bytes.
Returns:
Union[_DoNothingHash, hashlib.md5]: Either the original ``md5_hash``
if ``_decoder`` is not patched. Otherwise, returns a ``_DoNothingHash``
since the caller will no longer need to hash to decoded bytes.
"""
encoding = response_raw.headers.get(u'content-encoding', u'').lower()
if encoding != u'gzip':
return md5_hash

response_raw._decoder = _GzipDecoder(md5_hash)
return _DoNothingHash()


class _GzipDecoder(urllib3.response.GzipDecoder):
    """Custom subclass of ``urllib3`` decoder for ``gzip``-ed bytes.

    Allows an MD5 hash function to see the compressed bytes before they are
    decoded. This way the hash of the compressed value can be computed.

    Args:
        md5_hash (Union[_DoNothingHash, hashlib.md5]): A hash function which
            will get updated when it encounters compressed bytes.
    """

    def __init__(self, md5_hash):
        super(_GzipDecoder, self).__init__()
        # Kept so ``decompress`` can update it on every chunk.
        self._md5_hash = md5_hash

    def decompress(self, data):
        """Decompress the bytes.

        Args:
            data (bytes): The compressed bytes to be decompressed.

        Returns:
            bytes: The decompressed bytes from ``data``.
        """
        # Hash the still-compressed bytes *before* inflating them, so the
        # checksum corresponds to the payload as sent on the wire.
        self._md5_hash.update(data)
        return super(_GzipDecoder, self).decompress(data)
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
GOOGLE_AUTH = 'google-auth >= 0.10.0'


@nox.session(python=['2,7', '3.4', '3.5', '3.6', '3.7'])
@nox.session(python=['2.7', '3.4', '3.5', '3.6', '3.7'])
def unit_tests(session):
"""Run the unit test suite."""

Expand Down
24 changes: 13 additions & 11 deletions tests/system/requests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from six.moves import http_client

from google import resumable_media
import google.resumable_media.requests as resumable_requests
import google.resumable_media.requests.download as download_mod
from google.resumable_media import requests as resumable_requests
from google.resumable_media.requests import download as download_mod
from google.resumable_media.requests import _helpers
from tests.system import utils


Expand Down Expand Up @@ -58,12 +59,10 @@
}, {
u'path': os.path.realpath(os.path.join(DATA_DIR, u'file.txt')),
u'content_type': PLAIN_TEXT,
u'checksum': u'KHRs/+ZSrc/FuuR4qz/PZQ==',
u'checksum': u'XHSHAr/SpIeZtZbjgQ4nGw==',
u'slices': (),
}, {
u'path': os.path.realpath(os.path.join(DATA_DIR, u'gzipped.txt.gz')),
u'uncompressed':
os.path.realpath(os.path.join(DATA_DIR, u'gzipped.txt')),
u'content_type': PLAIN_TEXT,
u'checksum': u'KHRs/+ZSrc/FuuR4qz/PZQ==',
u'slices': (),
Expand Down Expand Up @@ -131,13 +130,13 @@ def _get_contents_for_upload(info):


def _get_contents(info):
full_path = info.get(u'uncompressed', info[u'path'])
full_path = info[u'path']
with open(full_path, u'rb') as file_obj:
return file_obj.read()


def _get_blob_name(info):
full_path = info.get(u'uncompressed', info[u'path'])
full_path = info[u'path']
return os.path.basename(full_path)


Expand Down Expand Up @@ -184,6 +183,11 @@ def check_tombstoned(download, transport):
assert exc_info.match(u'Download has finished.')


def read_raw_content(response):
    """Return the undecoded payload bytes of ``response``.

    Streams ``response.raw`` in fixed-size chunks with
    ``decode_content=False`` so the bytes match what was sent on the wire.
    """
    chunks = response.raw.stream(
        _helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False)
    return b''.join(chunks)


def test_download_full(add_files, authorized_transport):
for info in ALL_FILES:
actual_contents = _get_contents(info)
Expand All @@ -195,7 +199,7 @@ def test_download_full(add_files, authorized_transport):
# Consume the resource.
response = download.consume(authorized_transport)
assert response.status_code == http_client.OK
assert response.content == actual_contents
assert read_raw_content(response) == actual_contents
check_tombstoned(download, authorized_transport)


Expand All @@ -220,7 +224,6 @@ def test_download_to_stream(add_files, authorized_transport):
check_tombstoned(download, authorized_transport)


@pytest.mark.xfail # See: #76
def test_corrupt_download(add_files, corrupting_transport):
for info in ALL_FILES:
blob_name = _get_blob_name(info)
Expand Down Expand Up @@ -394,8 +397,7 @@ def consume_chunks(download, authorized_transport,
return num_responses, response


@pytest.mark.xfail # See issue #56
def test_chunked_download(add_files, authorized_transport):
def test_chunked_download_full(add_files, authorized_transport):
for info in ALL_FILES:
actual_contents = _get_contents(info)
blob_name = _get_blob_name(info)
Expand Down
15 changes: 12 additions & 3 deletions tests/unit/requests/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,21 @@ def test__get_status_code(self):

def test__get_headers(self):
    """``_get_headers`` should return the response's headers unchanged."""
    # The rendered diff left both the old and new ``response`` assignments
    # in place (``spec=[u'headers']`` vs ``spec=['headers']``); keep one.
    headers = {u'fruit': u'apple'}
    # ``spec`` restricts the mock so only ``headers`` may be accessed.
    response = mock.Mock(headers=headers, spec=[u'headers'])
    assert headers == _helpers.RequestsMixin._get_headers(response)

def test__get_body_wo_content_consumed(self):
    """When ``_content`` is unset, ``_get_body`` drains ``raw.stream``.

    The rendered diff left the superseded ``test__get_body`` header and its
    ``response = mock.Mock(content=body, ...)`` line interleaved here; only
    the current method is kept.
    """
    body = b'This is the payload.'
    raw = mock.Mock(spec=['stream'])
    raw.stream.return_value = iter([body])
    # ``_content=False`` is the requests sentinel for "not yet consumed".
    response = mock.Mock(raw=raw, _content=False, spec=['raw', '_content'])
    assert body == _helpers.RequestsMixin._get_body(response)
    # The body must be read in fixed-size chunks without decoding.
    raw.stream.assert_called_once_with(
        _helpers._SINGLE_GET_CHUNK_SIZE, decode_content=False)

def test__get_body_w_content_consumed(self):
    """Once consumed, ``_get_body`` returns the cached ``_content`` bytes."""
    payload = b'This is the payload.'
    # Only ``_content`` is exposed: the helper must not touch ``raw``.
    response = mock.Mock(_content=payload, spec=['_content'])
    result = _helpers.RequestsMixin._get_body(response)
    assert result == payload


Expand Down
Loading

0 comments on commit 2b9ffc8

Please sign in to comment.