diff --git a/google/resumable_media/_download.py b/google/resumable_media/_download.py index b2bac98c..7958e3c0 100644 --- a/google/resumable_media/_download.py +++ b/google/resumable_media/_download.py @@ -139,6 +139,10 @@ def __init__( media_url, stream=stream, start=start, end=end, headers=headers ) self.checksum = checksum + self._bytes_downloaded = 0 + self._expected_checksum = None + self._checksum_object = None + self._object_generation = None def _prepare_request(self): """Prepare the contents of an HTTP request. diff --git a/google/resumable_media/_helpers.py b/google/resumable_media/_helpers.py index dc6fdaf1..2043d19d 100644 --- a/google/resumable_media/_helpers.py +++ b/google/resumable_media/_helpers.py @@ -22,6 +22,11 @@ import random import warnings +from urllib.parse import parse_qs +from urllib.parse import urlencode +from urllib.parse import urlsplit +from urllib.parse import urlunsplit + from google.resumable_media import common @@ -33,6 +38,7 @@ "implementation. Python 3 has a faster implementation, `google-crc32c`, " "which will be used if it is installed." ) +_GENERATION_HEADER = "x-goog-generation" _HASH_HEADER = "x-goog-hash" _MISSING_CHECKSUM = """\ No {checksum_type} checksum was returned from the service while downloading {} @@ -302,6 +308,67 @@ def _get_checksum_object(checksum_type): raise ValueError("checksum must be ``'md5'``, ``'crc32c'`` or ``None``") +def _parse_generation_header(response, get_headers): + """Parses the generation header from an ``X-Goog-Generation`` value. + + Args: + response (~requests.Response): The HTTP response object. + get_headers (callable: response->dict): returns response headers. + + Returns: + Optional[long]: The object generation from the response, if it + can be detected from the ``X-Goog-Generation`` header; otherwise, None. + """ + headers = get_headers(response) + object_generation = headers.get(_GENERATION_HEADER, None) + + if object_generation is None: + return None + else: + return int(object_generation) + + +def _get_generation_from_url(media_url): + """Retrieve the object generation query param specified in the media url. + + Args: + media_url (str): The URL containing the media to be downloaded. + + Returns: + long: The object generation from the media url if exists; otherwise, None. + """ + + _, _, _, query, _ = urlsplit(media_url) + query_params = parse_qs(query) + object_generation = query_params.get("generation", None) + + if object_generation is None: + return None + else: + return int(object_generation[0]) + + +def add_query_parameters(media_url, query_params): + """Add query parameters to a base url. + + Args: + media_url (str): The URL containing the media to be downloaded. + query_params (dict): Names and values of the query parameters to add. + + Returns: + str: URL with additional query strings appended. + """ + + if len(query_params) == 0: + return media_url + + scheme, netloc, path, query, frag = urlsplit(media_url) + params = parse_qs(query) + new_params = {**params, **query_params} + query = urlencode(new_params, doseq=True) + return urlunsplit((scheme, netloc, path, query, frag)) + + class _DoNothingHash(object): """Do-nothing hash object. diff --git a/google/resumable_media/requests/download.py b/google/resumable_media/requests/download.py index d5b130a0..58de0100 100644 --- a/google/resumable_media/requests/download.py +++ b/google/resumable_media/requests/download.py @@ -86,12 +86,22 @@ def _write_to_stream(self, response): checksum doesn't agree with server-computed checksum. """ - # `_get_expected_checksum()` may return None even if a checksum was - # requested, in which case it will emit an info log _MISSING_CHECKSUM. - # If an invalid checksum type is specified, this will raise ValueError. - expected_checksum, checksum_object = _helpers._get_expected_checksum( - response, self._get_headers, self.media_url, checksum_type=self.checksum - ) + # Retrieve the expected checksum only once for the download request, + # then compute and validate the checksum when the full download completes. + # Retried requests are range requests, and there's no way to detect + # data corruption for that byte range alone. + if self._expected_checksum is None and self._checksum_object is None: + # `_get_expected_checksum()` may return None even if a checksum was + # requested, in which case it will emit an info log _MISSING_CHECKSUM. + # If an invalid checksum type is specified, this will raise ValueError. + expected_checksum, checksum_object = _helpers._get_expected_checksum( + response, self._get_headers, self.media_url, checksum_type=self.checksum + ) + self._expected_checksum = expected_checksum + self._checksum_object = checksum_object + else: + expected_checksum = self._expected_checksum + checksum_object = self._checksum_object with response: # NOTE: In order to handle compressed streams gracefully, we try @@ -104,6 +114,7 @@ def _write_to_stream(self, response): ) for chunk in body_iter: self._stream.write(chunk) + self._bytes_downloaded += len(chunk) local_checksum_object.update(chunk) if expected_checksum is not None: @@ -150,7 +161,7 @@ def consume( ValueError: If the current :class:`Download` has already finished. """ - method, url, payload, headers = self._prepare_request() + method, _, payload, headers = self._prepare_request() # NOTE: We assume "payload is None" but pass it along anyway. request_kwargs = { "data": payload, @@ -160,10 +171,39 @@ def consume( if self._stream is not None: request_kwargs["stream"] = True + # Assign object generation if generation is specified in the media url. + if self._object_generation is None: + self._object_generation = _helpers._get_generation_from_url(self.media_url) + # Wrap the request business logic in a function to be retried. def retriable_request(): + url = self.media_url + + # To restart an interrupted download, read from the offset of last byte + # received using a range request, and set object generation query param. + if self._bytes_downloaded > 0: + _download.add_bytes_range( + self._bytes_downloaded, self.end, self._headers + ) + request_kwargs["headers"] = self._headers + + # Set object generation query param to ensure the same object content is requested. + if ( + self._object_generation is not None + and _helpers._get_generation_from_url(self.media_url) is None + ): + query_param = {"generation": self._object_generation} + url = _helpers.add_query_parameters(self.media_url, query_param) + result = transport.request(method, url, **request_kwargs) + # If a generation hasn't been specified, and this is the first response we get, let's record the + # generation. In future requests we'll specify the generation query param to avoid data races. + if self._object_generation is None: + self._object_generation = _helpers._parse_generation_header( + result, self._get_headers + ) + self._process_response(result) if self._stream is not None: @@ -223,13 +263,22 @@ def _write_to_stream(self, response): ~google.resumable_media.common.DataCorruption: If the download's checksum doesn't agree with server-computed checksum. """ - - # `_get_expected_checksum()` may return None even if a checksum was - # requested, in which case it will emit an info log _MISSING_CHECKSUM. - # If an invalid checksum type is specified, this will raise ValueError. - expected_checksum, checksum_object = _helpers._get_expected_checksum( - response, self._get_headers, self.media_url, checksum_type=self.checksum - ) + # Retrieve the expected checksum only once for the download request, + # then compute and validate the checksum when the full download completes. + # Retried requests are range requests, and there's no way to detect + # data corruption for that byte range alone. + if self._expected_checksum is None and self._checksum_object is None: + # `_get_expected_checksum()` may return None even if a checksum was + # requested, in which case it will emit an info log _MISSING_CHECKSUM. + # If an invalid checksum type is specified, this will raise ValueError. + expected_checksum, checksum_object = _helpers._get_expected_checksum( + response, self._get_headers, self.media_url, checksum_type=self.checksum + ) + self._expected_checksum = expected_checksum + self._checksum_object = checksum_object + else: + expected_checksum = self._expected_checksum + checksum_object = self._checksum_object with response: body_iter = response.raw.stream( @@ -237,6 +286,7 @@ def _write_to_stream(self, response): ) for chunk in body_iter: self._stream.write(chunk) + self._bytes_downloaded += len(chunk) checksum_object.update(chunk) response._content_consumed = True @@ -285,19 +335,47 @@ def consume( ValueError: If the current :class:`Download` has already finished. """ - method, url, payload, headers = self._prepare_request() + method, _, payload, headers = self._prepare_request() + # NOTE: We assume "payload is None" but pass it along anyway. + request_kwargs = { + "data": payload, + "headers": headers, + "timeout": timeout, + "stream": True, + } + + # Assign object generation if generation is specified in the media url. + if self._object_generation is None: + self._object_generation = _helpers._get_generation_from_url(self.media_url) # Wrap the request business logic in a function to be retried. def retriable_request(): - # NOTE: We assume "payload is None" but pass it along anyway. - result = transport.request( - method, - url, - data=payload, - headers=headers, - stream=True, - timeout=timeout, - ) + url = self.media_url + + # To restart an interrupted download, read from the offset of last byte + # received using a range request, and set object generation query param. + if self._bytes_downloaded > 0: + _download.add_bytes_range( + self._bytes_downloaded, self.end, self._headers + ) + request_kwargs["headers"] = self._headers + + # Set object generation query param to ensure the same object content is requested. + if ( + self._object_generation is not None + and _helpers._get_generation_from_url(self.media_url) is None + ): + query_param = {"generation": self._object_generation} + url = _helpers.add_query_parameters(self.media_url, query_param) + + result = transport.request(method, url, **request_kwargs) + + # If a generation hasn't been specified, and this is the first response we get, let's record the + # generation. In future requests we'll specify the generation query param to avoid data races. + if self._object_generation is None: + self._object_generation = _helpers._parse_generation_header( + result, self._get_headers + ) self._process_response(result) diff --git a/tests/unit/requests/test_download.py b/tests/unit/requests/test_download.py index df12ca44..210973d7 100644 --- a/tests/unit/requests/test_download.py +++ b/tests/unit/requests/test_download.py @@ -42,6 +42,7 @@ def test__write_to_stream_no_hash_check(self): assert ret_val is None assert stream.getvalue() == chunk1 + chunk2 + assert download._bytes_downloaded == len(chunk1 + chunk2) # Check mocks. response.__enter__.assert_called_once_with() @@ -66,6 +67,8 @@ def test__write_to_stream_with_hash_check_success(self, checksum): assert ret_val is None assert stream.getvalue() == chunk1 + chunk2 + chunk3 + assert download._bytes_downloaded == len(chunk1 + chunk2 + chunk3) + assert download._checksum_object is not None # Check mocks. response.__enter__.assert_called_once_with() @@ -273,6 +276,125 @@ def test_consume_with_headers(self): # Make sure the headers have been modified. assert headers == {"range": range_bytes} + def test_consume_gets_generation_from_url(self): + GENERATION_VALUE = 1641590104888641 + url = EXAMPLE_URL + f"&generation={GENERATION_VALUE}" + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + + download = download_mod.Download( + url, stream=stream, end=65536, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_response(chunks=chunks, headers=None) + + assert not download.finished + assert download._object_generation is None + + ret_val = download.consume(transport) + + assert download._object_generation == GENERATION_VALUE + assert ret_val is transport.request.return_value + assert stream.getvalue() == b"".join(chunks) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", url, **called_kwargs) + + def test_consume_gets_generation_from_headers(self): + GENERATION_VALUE = 1641590104888641 + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + + download = download_mod.Download( + EXAMPLE_URL, stream=stream, end=65536, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + headers = {_helpers._GENERATION_HEADER: GENERATION_VALUE} + transport.request.return_value = _mock_response(chunks=chunks, headers=headers) + + assert not download.finished + assert download._object_generation is None + + ret_val = download.consume(transport) + + assert download._object_generation == GENERATION_VALUE + assert ret_val is transport.request.return_value + assert stream.getvalue() == b"".join(chunks) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", EXAMPLE_URL, **called_kwargs) + + def test_consume_w_object_generation(self): + GENERATION_VALUE = 1641590104888641 + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + end = 65536 + + download = download_mod.Download( + EXAMPLE_URL, stream=stream, end=end, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_response(chunks=chunks, headers=None) + + assert download._object_generation is None + + # Mock a retry operation with object generation retrieved and bytes already downloaded in the stream + download._object_generation = GENERATION_VALUE + offset = 256 + download._bytes_downloaded = offset + download.consume(transport) + + expected_url = EXAMPLE_URL + f"&generation={GENERATION_VALUE}" + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", expected_url, **called_kwargs) + range_bytes = "bytes={:d}-{:d}".format(offset, end) + assert download._headers["range"] == range_bytes + + def test_consume_w_bytes_downloaded(self): + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + end = 65536 + + download = download_mod.Download( + EXAMPLE_URL, stream=stream, end=end, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_response(chunks=chunks, headers=None) + + assert download._bytes_downloaded == 0 + + # Mock a retry operation with bytes already downloaded in the stream and checksum stored + offset = 256 + download._bytes_downloaded = offset + download._expected_checksum = None + download._checksum_object = _helpers._DoNothingHash() + download.consume(transport) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", EXAMPLE_URL, **called_kwargs) + range_bytes = "bytes={:d}-{:d}".format(offset, end) + assert download._headers["range"] == range_bytes + class TestRawDownload(object): def test__write_to_stream_no_hash_check(self): @@ -287,6 +409,7 @@ def test__write_to_stream_no_hash_check(self): assert ret_val is None assert stream.getvalue() == chunk1 + chunk2 + assert download._bytes_downloaded == len(chunk1 + chunk2) # Check mocks. response.__enter__.assert_called_once_with() @@ -313,6 +436,8 @@ def test__write_to_stream_with_hash_check_success(self, checksum): assert ret_val is None assert stream.getvalue() == chunk1 + chunk2 + chunk3 + assert download._bytes_downloaded == len(chunk1 + chunk2 + chunk3) + assert download._checksum_object is not None # Check mocks. response.__enter__.assert_called_once_with() @@ -526,6 +651,127 @@ def test_consume_with_headers(self): # Make sure the headers have been modified. assert headers == {"range": range_bytes} + def test_consume_gets_generation_from_url(self): + GENERATION_VALUE = 1641590104888641 + url = EXAMPLE_URL + f"&generation={GENERATION_VALUE}" + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + + download = download_mod.RawDownload( + url, stream=stream, end=65536, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_raw_response(chunks=chunks, headers=None) + + assert not download.finished + assert download._object_generation is None + + ret_val = download.consume(transport) + + assert download._object_generation == GENERATION_VALUE + assert ret_val is transport.request.return_value + assert stream.getvalue() == b"".join(chunks) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", url, **called_kwargs) + + def test_consume_gets_generation_from_headers(self): + GENERATION_VALUE = 1641590104888641 + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + + download = download_mod.RawDownload( + EXAMPLE_URL, stream=stream, end=65536, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + headers = {_helpers._GENERATION_HEADER: GENERATION_VALUE} + transport.request.return_value = _mock_raw_response( + chunks=chunks, headers=headers + ) + + assert not download.finished + assert download._object_generation is None + + ret_val = download.consume(transport) + + assert download._object_generation == GENERATION_VALUE + assert ret_val is transport.request.return_value + assert stream.getvalue() == b"".join(chunks) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", EXAMPLE_URL, **called_kwargs) + + def test_consume_w_object_generation(self): + GENERATION_VALUE = 1641590104888641 + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + end = 65536 + + download = download_mod.RawDownload( + EXAMPLE_URL, stream=stream, end=end, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_raw_response(chunks=chunks, headers=None) + + assert download._object_generation is None + + # Mock a retry operation with object generation retrieved and bytes already downloaded in the stream + download._object_generation = GENERATION_VALUE + offset = 256 + download._bytes_downloaded = offset + download.consume(transport) + + expected_url = EXAMPLE_URL + f"&generation={GENERATION_VALUE}" + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", expected_url, **called_kwargs) + range_bytes = "bytes={:d}-{:d}".format(offset, end) + assert download._headers["range"] == range_bytes + + def test_consume_w_bytes_downloaded(self): + stream = io.BytesIO() + chunks = (b"up down ", b"charlie ", b"brown") + end = 65536 + + download = download_mod.RawDownload( + EXAMPLE_URL, stream=stream, end=end, headers=None, checksum="md5" + ) + transport = mock.Mock(spec=["request"]) + transport.request.return_value = _mock_raw_response(chunks=chunks, headers=None) + + assert download._bytes_downloaded == 0 + + # Mock a retry operation with bytes already downloaded in the stream and checksum stored + offset = 256 + download._bytes_downloaded = offset + download._expected_checksum = None + download._checksum_object = _helpers._DoNothingHash() + download.consume(transport) + + called_kwargs = { + "data": None, + "headers": download._headers, + "timeout": EXPECTED_TIMEOUT, + "stream": True, + } + transport.request.assert_called_once_with("GET", EXAMPLE_URL, **called_kwargs) + range_bytes = "bytes={:d}-{:d}".format(offset, end) + assert download._headers["range"] == range_bytes + class TestChunkedDownload(object): @staticmethod diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index fdb9c77f..feedeb18 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -408,6 +408,66 @@ def test_md5_multiple_matches(self): assert error.args[2] == [self.MD5_CHECKSUM, another_checksum] +class Test__parse_generation_header(object): + + GENERATION_VALUE = 1641590104888641 + + def test_empty_value(self): + headers = {} + response = _mock_response(headers=headers) + generation_header = _helpers._parse_generation_header(response, _get_headers) + assert generation_header is None + + def test_header_value(self): + headers = {_helpers._GENERATION_HEADER: self.GENERATION_VALUE} + response = _mock_response(headers=headers) + generation_header = _helpers._parse_generation_header(response, _get_headers) + assert generation_header == self.GENERATION_VALUE + + +class Test__get_generation_from_url(object): + + GENERATION_VALUE = 1641590104888641 + MEDIA_URL = ( + "https://storage.googleapis.com/storage/v1/b/my-bucket/o/my-object?alt=media" + ) + MEDIA_URL_W_GENERATION = MEDIA_URL + f"&generation={GENERATION_VALUE}" + + def test_empty_value(self): + generation = _helpers._get_generation_from_url(self.MEDIA_URL) + assert generation is None + + def test_generation_in_url(self): + generation = _helpers._get_generation_from_url(self.MEDIA_URL_W_GENERATION) + assert generation == self.GENERATION_VALUE + + +class Test__add_query_parameters(object): + def test_w_empty_list(self): + query_params = {} + MEDIA_URL = "https://storage.googleapis.com/storage/v1/b/my-bucket/o/my-object" + new_url = _helpers.add_query_parameters(MEDIA_URL, query_params) + assert new_url == MEDIA_URL + + def test_wo_existing_qs(self): + query_params = {"one": "One", "two": "Two"} + MEDIA_URL = "https://storage.googleapis.com/storage/v1/b/my-bucket/o/my-object" + expected = "&".join( + ["{}={}".format(name, value) for name, value in query_params.items()] + ) + new_url = _helpers.add_query_parameters(MEDIA_URL, query_params) + assert new_url == "{}?{}".format(MEDIA_URL, expected) + + def test_w_existing_qs(self): + query_params = {"one": "One", "two": "Two"} + MEDIA_URL = "https://storage.googleapis.com/storage/v1/b/my-bucket/o/my-object?alt=media" + expected = "&".join( + ["{}={}".format(name, value) for name, value in query_params.items()] + ) + new_url = _helpers.add_query_parameters(MEDIA_URL, query_params) + assert new_url == "{}&{}".format(MEDIA_URL, expected) + + def _mock_response(headers): return mock.Mock( headers=headers,