From 68d502eb007df28033771a31cfa6d6ed5a084572 Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Thu, 6 May 2021 19:40:50 -0700 Subject: [PATCH] Core raw streaming (#17920) * add raw streaming support --- .../azure-core/CLIENT_LIBRARY_DEVELOPER.md | 2 +- .../azure/core/pipeline/transport/_aiohttp.py | 36 ++++++++++++----- .../azure/core/pipeline/transport/_base.py | 4 +- .../core/pipeline/transport/_base_async.py | 6 ++- .../pipeline/transport/_requests_asyncio.py | 21 ++++++---- .../pipeline/transport/_requests_basic.py | 40 ++++++++++++++++--- .../core/pipeline/transport/_requests_trio.py | 19 ++++++--- .../test_stream_generator_async.py | 22 +++++++++- .../azure-core/tests/test_stream_generator.py | 27 +++++++++++-- 9 files changed, 138 insertions(+), 39 deletions(-) diff --git a/sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md b/sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md index fe47c1348586..32b9196088cd 100644 --- a/sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md +++ b/sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md @@ -279,7 +279,7 @@ class HttpResponse(object): def text(self, encoding=None): """Return the whole body as a string.""" - def stream_download(self, chunk_size=None, callback=None): + def stream_download(self, pipeline, **kwargs): """Generator for streaming request body data. Should be implemented by sub-classes if streaming download is supported. diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py index af553bc71b30..da70db52b5be 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py @@ -46,7 +46,6 @@ CONTENT_CHUNK_SIZE = 10 * 1024 _LOGGER = logging.getLogger(__name__) - class AioHttpTransport(AsyncHttpTransport): """AioHttp HTTP sender implementation. @@ -89,7 +88,8 @@ async def open(self): self.session = aiohttp.ClientSession( loop=self._loop, trust_env=self._use_env_settings, - cookie_jar=jar + cookie_jar=jar, + auto_decompress=False, ) if self.session is not None: await self.session.__aenter__() @@ -191,22 +191,24 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR raise ServiceResponseError(err, error=err) from err return response - class AioHttpStreamDownloadGenerator(AsyncIterator): """Streams the response body data. :param pipeline: The pipeline object :param response: The client response object. - :param block_size: block size of data sent over connection. - :type block_size: int + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ - def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size + self._decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) self.content_length = int(response.internal_response.headers.get('Content-Length', 0)) - self.downloaded = 0 + self._decompressor = None def __len__(self): return self.content_length @@ -216,6 +218,18 @@ async def __anext__(self): chunk = await self.response.internal_response.content.read(self.block_size) if not chunk: raise _ResponseStopIteration() + if not self._decompress: + return chunk + enc = self.response.internal_response.headers.get('Content-Encoding') + if not enc: + return chunk + enc = enc.lower() + if enc in ("gzip", "deflate"): + if not self._decompressor: + import zlib + zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS + self._decompressor = zlib.decompressobj(wbits=zlib_mode) + chunk = self._decompressor.decompress(chunk) return chunk except _ResponseStopIteration: self.response.internal_response.close() @@ -269,13 +283,15 @@ async def load_body(self) -> None: """Load in memory the body, so it could be accessible from sync methods.""" self._body = await self.internal_response.read() - def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. :param pipeline: The pipeline object - :type pipeline: azure.core.pipeline + :type pipeline: azure.core.pipeline.Pipeline + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ - return AioHttpStreamDownloadGenerator(pipeline, self) + return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs) def __getstate__(self): # Be sure body is loaded in memory, otherwise not pickable and let it throw diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 6e23d58888a4..589d5549c584 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -580,8 +580,8 @@ def __repr__(self): class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method - def stream_download(self, pipeline): - # type: (PipelineType) -> Iterator[bytes] + def stream_download(self, pipeline, **kwargs): + # type: (PipelineType, **Any) -> Iterator[bytes] """Generator for streaming request body data. Should be implemented by sub-classes if streaming download diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py index bfc51ef6109b..adf09cc10a51 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base_async.py @@ -124,14 +124,16 @@ class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method Allows for the asynchronous streaming of data from the response. """ - def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: """Generator for streaming response body data. Should be implemented by sub-classes if streaming download is supported. Will return an asynchronous generator. :param pipeline: The pipeline object - :type pipeline: azure.core.pipeline + :type pipeline: azure.core.pipeline.Pipeline + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ def parts(self) -> AsyncIterator: diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py index 90c53675f866..4f070834fa85 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_asyncio.py @@ -42,7 +42,7 @@ AsyncHttpResponse, _ResponseStopIteration, _iterate_response_content) -from ._requests_basic import RequestsTransportResponse +from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase @@ -138,17 +138,22 @@ class AsyncioStreamDownloadGenerator(AsyncIterator): :param pipeline: The pipeline object :param response: The response object. - :param generator iter_content_func: Iterator for response data. - :param int content_length: size of body in bytes. + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ - def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + if decompress: + self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) - self.downloaded = 0 def __len__(self): return self.content_length @@ -178,6 +183,6 @@ async def __anext__(self): class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore """Asynchronous streaming of data from the response. """ - def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore """Generator for streaming request body data.""" - return AsyncioStreamDownloadGenerator(pipeline, self) # type: ignore + return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs) # type: ignore diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py index ed8a7382c55d..4cba767842ff 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_basic.py @@ -28,6 +28,9 @@ from typing import Iterator, Optional, Any, Union, TypeVar import urllib3 # type: ignore from urllib3.util.retry import Retry # type: ignore +from urllib3.exceptions import ( + DecodeError, ReadTimeoutError, ProtocolError +) import requests from azure.core.configuration import ConnectionConfiguration @@ -48,6 +51,25 @@ _LOGGER = logging.getLogger(__name__) +def _read_raw_stream(response, chunk_size=1): + # Special case for urllib3. + if hasattr(response.raw, 'stream'): + try: + for chunk in response.raw.stream(chunk_size, decode_content=False): + yield chunk + except ProtocolError as e: + raise requests.exceptions.ChunkedEncodingError(e) + except DecodeError as e: + raise requests.exceptions.ContentDecodingError(e) + except ReadTimeoutError as e: + raise requests.exceptions.ConnectionError(e) + else: + # Standard file-like object. + while True: + chunk = response.raw.read(chunk_size) + if not chunk: + break + yield chunk class _RequestsTransportResponseBase(_HttpResponseBase): """Base class for accessing response data. @@ -98,13 +120,21 @@ class StreamDownloadGenerator(object): :param pipeline: The pipeline object :param response: The response object. + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ - def __init__(self, pipeline, response): + def __init__(self, pipeline, response, **kwargs): self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + if decompress: + self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) def __len__(self): @@ -134,10 +164,10 @@ def __next__(self): class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase): """Streaming of data from the response. """ - def stream_download(self, pipeline): - # type: (PipelineType) -> Iterator[bytes] + def stream_download(self, pipeline, **kwargs): + # type: (PipelineType, **Any) -> Iterator[bytes] """Generator for streaming request body data.""" - return StreamDownloadGenerator(pipeline, self) + return StreamDownloadGenerator(pipeline, self, **kwargs) class RequestsTransport(HttpTransport): diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py index 04ddd453bbf5..58fa2722e2d0 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_requests_trio.py @@ -42,7 +42,7 @@ AsyncHttpResponse, _ResponseStopIteration, _iterate_response_content) -from ._requests_basic import RequestsTransportResponse +from ._requests_basic import RequestsTransportResponse, _read_raw_stream from ._base_requests_async import RequestsAsyncTransportBase @@ -54,15 +54,22 @@ class TrioStreamDownloadGenerator(AsyncIterator): :param pipeline: The pipeline object :param response: The response object. + :keyword bool decompress: If True which is default, will attempt to decode the body based + on the ‘content-encoding’ header. """ - def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None: + def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None: self.pipeline = pipeline self.request = response.request self.response = response self.block_size = response.block_size - self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + decompress = kwargs.pop("decompress", True) + if len(kwargs) > 0: + raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0])) + if decompress: + self.iter_content_func = self.response.internal_response.iter_content(self.block_size) + else: + self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size) self.content_length = int(response.headers.get('Content-Length', 0)) - self.downloaded = 0 def __len__(self): return self.content_length @@ -95,10 +102,10 @@ async def __anext__(self): class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore """Asynchronous streaming of data from the response. """ - def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore + def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]: # type: ignore """Generator for streaming response data. """ - return TrioStreamDownloadGenerator(pipeline, self) + return TrioStreamDownloadGenerator(pipeline, self, **kwargs) class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index 0a4d6017e6c4..de7bc894e42d 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -17,9 +17,18 @@ @pytest.mark.asyncio async def test_connection_error_response(): + class MockSession(object): + def __init__(self): + self.auto_decompress = True + + @property + def auto_decompress(self): + return self.auto_decompress + class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 + self.session = MockSession async def __aexit__(self, exc_type, exc_val, exc_tb): pass @@ -60,7 +69,7 @@ async def __call__(self, *args, **kwargs): pipeline = AsyncPipeline(MockTransport()) http_response = AsyncHttpResponse(http_request, None) http_response.internal_response = MockInternalResponse() - stream = AioHttpStreamDownloadGenerator(pipeline, http_response) + stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('asyncio.sleep', new_callable=AsyncMock): with pytest.raises(ConnectionError): await stream.__anext__() @@ -75,6 +84,8 @@ async def test_response_streaming_error_behavior(): class FakeStreamWithConnectionError: # fake object for urllib3.response.HTTPResponse + def __init__(self): + self.total_response_size = 500 def stream(self, chunk_size, decode_content=False): assert chunk_size == block_size @@ -86,6 +97,15 @@ def stream(self, chunk_size, decode_content=False): left -= len(data) yield data + def read(self, chunk_size, decode_content=False): + assert chunk_size == block_size + if self.total_response_size > 0: + if self.total_response_size <= block_size: + raise requests.exceptions.ConnectionError() + data = b"X" * min(chunk_size, self.total_response_size) + self.total_response_size -= len(data) + return data + def close(self): pass diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index 8b30b00a4f50..c43053eeab9d 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -43,10 +43,17 @@ def __next__(self): if self._count == 0: self._count += 1 raise requests.exceptions.ConnectionError + + def stream(self, chunk_size, decode_content=False): + if self._count == 0: + self._count += 1 + raise requests.exceptions.ConnectionError + while True: + yield b"test" class MockInternalResponse(): - def iter_content(self, block_size): - return MockTransport() + def __init__(self): + self.raw = MockTransport() def close(self): pass @@ -55,7 +62,7 @@ def close(self): pipeline = Pipeline(MockTransport()) http_response = HttpResponse(http_request, None) http_response.internal_response = MockInternalResponse() - stream = StreamDownloadGenerator(pipeline, http_response) + stream = StreamDownloadGenerator(pipeline, http_response, decompress=False) with mock.patch('time.sleep', return_value=None): with pytest.raises(requests.exceptions.ConnectionError): stream.__next__() @@ -69,6 +76,8 @@ def test_response_streaming_error_behavior(): class FakeStreamWithConnectionError: # fake object for urllib3.response.HTTPResponse + def __init__(self): + self.total_response_size = 500 def stream(self, chunk_size, decode_content=False): assert chunk_size == block_size @@ -80,9 +89,19 @@ def stream(self, chunk_size, decode_content=False): left -= len(data) yield data + def read(self, chunk_size, decode_content=False): + assert chunk_size == block_size + if self.total_response_size > 0: + if self.total_response_size <= block_size: + raise requests.exceptions.ConnectionError() + data = b"X" * min(chunk_size, self.total_response_size) + self.total_response_size -= len(data) + return data + def close(self): pass + s = FakeStreamWithConnectionError() req_response.raw = FakeStreamWithConnectionError() response = RequestsTransportResponse( @@ -101,6 +120,6 @@ def mock_run(self, *args, **kwargs): transport = RequestsTransport() pipeline = Pipeline(transport) pipeline.run = mock_run - downloader = response.stream_download(pipeline) + downloader = response.stream_download(pipeline, decompress=False) with pytest.raises(requests.exceptions.ConnectionError): full_response = b"".join(downloader)