Skip to content

Commit

Permalink
Core raw streaming (Azure#17920)
Browse files Browse the repository at this point in the history
* add raw streaming support
  • Loading branch information
xiangyan99 authored May 7, 2021
1 parent db9cde5 commit 68d502e
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 39 deletions.
2 changes: 1 addition & 1 deletion sdk/core/azure-core/CLIENT_LIBRARY_DEVELOPER.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class HttpResponse(object):
def text(self, encoding=None):
"""Return the whole body as a string."""

def stream_download(self, chunk_size=None, callback=None):
def stream_download(self, pipeline, **kwargs):
"""Generator for streaming request body data.
Should be implemented by sub-classes if streaming download
is supported.
Expand Down
36 changes: 26 additions & 10 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
CONTENT_CHUNK_SIZE = 10 * 1024
_LOGGER = logging.getLogger(__name__)


class AioHttpTransport(AsyncHttpTransport):
"""AioHttp HTTP sender implementation.
Expand Down Expand Up @@ -89,7 +88,8 @@ async def open(self):
self.session = aiohttp.ClientSession(
loop=self._loop,
trust_env=self._use_env_settings,
cookie_jar=jar
cookie_jar=jar,
auto_decompress=False,
)
if self.session is not None:
await self.session.__aenter__()
Expand Down Expand Up @@ -191,22 +191,24 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
raise ServiceResponseError(err, error=err) from err
return response


class AioHttpStreamDownloadGenerator(AsyncIterator):
"""Streams the response body data.
:param pipeline: The pipeline object
:param response: The client response object.
:param block_size: block size of data sent over connection.
:type block_size: int
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Wrap an aiohttp client response for chunked asynchronous iteration.

    :param pipeline: The pipeline object
    :param response: The client response object.
    :keyword bool decompress: If True (the default), attempt to inflate the
     body based on the 'content-encoding' header; if False, yield raw bytes.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    # Pop the one supported keyword first so anything left over is unknown.
    self._decompress = kwargs.pop("decompress", True)
    if kwargs:
        raise TypeError("Got an unexpected keyword argument: {}".format(next(iter(kwargs))))
    # Content-Length may be absent (e.g. chunked transfer); default to 0.
    self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
    self.downloaded = 0
    # zlib decompressor, created lazily on the first compressed chunk.
    self._decompressor = None

def __len__(self):
    # Length of the stream == the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand All @@ -216,6 +218,18 @@ async def __anext__(self):
chunk = await self.response.internal_response.content.read(self.block_size)
if not chunk:
raise _ResponseStopIteration()
if not self._decompress:
return chunk
enc = self.response.internal_response.headers.get('Content-Encoding')
if not enc:
return chunk
enc = enc.lower()
if enc in ("gzip", "deflate"):
if not self._decompressor:
import zlib
zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
self._decompressor = zlib.decompressobj(wbits=zlib_mode)
chunk = self._decompressor.decompress(chunk)
return chunk
except _ResponseStopIteration:
self.response.internal_response.close()
Expand Down Expand Up @@ -269,13 +283,15 @@ async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
self._body = await self.internal_response.read()

def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
"""Generator for streaming response body data.
:param pipeline: The pipeline object
:type pipeline: azure.core.pipeline
:type pipeline: azure.core.pipeline.Pipeline
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
return AioHttpStreamDownloadGenerator(pipeline, self)
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

def __getstate__(self):
# Be sure body is loaded in memory, otherwise not pickable and let it throw
Expand Down
4 changes: 2 additions & 2 deletions sdk/core/azure-core/azure/core/pipeline/transport/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ def __repr__(self):


class HttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
def stream_download(self, pipeline):
# type: (PipelineType) -> Iterator[bytes]
def stream_download(self, pipeline, **kwargs):
# type: (PipelineType, **Any) -> Iterator[bytes]
"""Generator for streaming request body data.
Should be implemented by sub-classes if streaming download
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,16 @@ class AsyncHttpResponse(_HttpResponseBase): # pylint: disable=abstract-method
Allows for the asynchronous streaming of data from the response.
"""

def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
    """Generator for streaming response body data.

    Should be implemented by sub-classes if streaming download
    is supported. Will return an asynchronous generator.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """

def parts(self) -> AsyncIterator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._base_requests_async import RequestsAsyncTransportBase


Expand Down Expand Up @@ -138,17 +138,22 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:param generator iter_content_func: Iterator for response data.
:param int content_length: size of body in bytes.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Streams the response body data (asyncio + requests transport).

    :param pipeline: The pipeline object
    :param response: The response object.
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    # Choose the iterator exactly once: the old unconditional iter_content()
    # call was removed — it started a *decoded* iterator even when
    # decompress=False, defeating raw streaming.
    if decompress:
        # requests handles content-encoding for us via iter_content.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass content decoding and read the raw urllib3 stream.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))
    self.downloaded = 0

def __len__(self):
    # Length of the stream == the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand Down Expand Up @@ -178,6 +183,6 @@ async def __anext__(self):
class AsyncioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
"""Asynchronous streaming of data from the response.
"""
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:  # type: ignore
    """Generator for streaming request body data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    # Forward **kwargs; the stale no-kwargs return that preceded this line
    # executed first and silently dropped the decompress option.
    return AsyncioStreamDownloadGenerator(pipeline, self, **kwargs)  # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from typing import Iterator, Optional, Any, Union, TypeVar
import urllib3 # type: ignore
from urllib3.util.retry import Retry # type: ignore
from urllib3.exceptions import (
DecodeError, ReadTimeoutError, ProtocolError
)
import requests

from azure.core.configuration import ConnectionConfiguration
Expand All @@ -48,6 +51,25 @@

_LOGGER = logging.getLogger(__name__)

def _read_raw_stream(response, chunk_size=1):
    """Yield the raw (undecoded) body bytes of *response* in chunks.

    :param response: A requests response whose ``raw`` attribute is either a
     urllib3 HTTPResponse or a plain file-like object.
    :param int chunk_size: Maximum number of bytes per yielded chunk.
    """
    raw = response.raw
    if hasattr(raw, 'stream'):
        # urllib3 HTTPResponse: stream without content decoding, translating
        # urllib3 errors into their requests-level equivalents.
        try:
            for part in raw.stream(chunk_size, decode_content=False):
                yield part
        except ProtocolError as e:
            raise requests.exceptions.ChunkedEncodingError(e)
        except DecodeError as e:
            raise requests.exceptions.ContentDecodingError(e)
        except ReadTimeoutError as e:
            raise requests.exceptions.ConnectionError(e)
    else:
        # Plain file-like object: read until EOF.
        while True:
            part = raw.read(chunk_size)
            if not part:
                return
            yield part

class _RequestsTransportResponseBase(_HttpResponseBase):
"""Base class for accessing response data.
Expand Down Expand Up @@ -98,13 +120,21 @@ class StreamDownloadGenerator(object):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline, response):
def __init__(self, pipeline, response, **kwargs):
    """Streams the response body data.

    :param pipeline: The pipeline object
    :param response: The response object.
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    # Choose the iterator exactly once: the old unconditional iter_content()
    # call was removed — it started a *decoded* iterator even when
    # decompress=False, defeating raw streaming.
    if decompress:
        # requests handles content-encoding for us via iter_content.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass content decoding and read the raw urllib3 stream.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))

def __len__(self):
Expand Down Expand Up @@ -134,10 +164,10 @@ def __next__(self):
class RequestsTransportResponse(HttpResponse, _RequestsTransportResponseBase):
"""Streaming of data from the response.
"""
def stream_download(self, pipeline):
# type: (PipelineType) -> Iterator[bytes]
def stream_download(self, pipeline, **kwargs):
    # type: (PipelineType, **Any) -> Iterator[bytes]
    """Generator for streaming request body data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    downloader = StreamDownloadGenerator(pipeline, self, **kwargs)
    return downloader


class RequestsTransport(HttpTransport):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
AsyncHttpResponse,
_ResponseStopIteration,
_iterate_response_content)
from ._requests_basic import RequestsTransportResponse
from ._requests_basic import RequestsTransportResponse, _read_raw_stream
from ._base_requests_async import RequestsAsyncTransportBase


Expand All @@ -54,15 +54,22 @@ class TrioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the ‘content-encoding’ header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
    """Streams the response body data (trio + requests transport).

    :param pipeline: The pipeline object
    :param response: The response object.
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    self.pipeline = pipeline
    self.request = response.request
    self.response = response
    self.block_size = response.block_size
    decompress = kwargs.pop("decompress", True)
    if len(kwargs) > 0:
        raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
    # Choose the iterator exactly once: the old unconditional iter_content()
    # call was removed — it started a *decoded* iterator even when
    # decompress=False, defeating raw streaming.
    if decompress:
        # requests handles content-encoding for us via iter_content.
        self.iter_content_func = self.response.internal_response.iter_content(self.block_size)
    else:
        # Bypass content decoding and read the raw urllib3 stream.
        self.iter_content_func = _read_raw_stream(self.response.internal_response, self.block_size)
    self.content_length = int(response.headers.get('Content-Length', 0))
    self.downloaded = 0

def __len__(self):
    # Length of the stream == the advertised Content-Length (0 when the header was absent).
    return self.content_length
Expand Down Expand Up @@ -95,10 +102,10 @@ async def __anext__(self):
class TrioRequestsTransportResponse(AsyncHttpResponse, RequestsTransportResponse): # type: ignore
"""Asynchronous streaming of data from the response.
"""
def stream_download(self, pipeline) -> AsyncIteratorType[bytes]: # type: ignore
def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:  # type: ignore
    """Generator for streaming response data.

    :param pipeline: The pipeline object
    :type pipeline: azure.core.pipeline.Pipeline
    :keyword bool decompress: If True which is default, will attempt to decode the body based
     on the 'content-encoding' header.
    """
    # Forward **kwargs; the stale no-kwargs return that preceded this line
    # executed first and silently dropped the decompress option.
    return TrioStreamDownloadGenerator(pipeline, self, **kwargs)


class TrioRequestsTransport(RequestsAsyncTransportBase): # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@

@pytest.mark.asyncio
async def test_connection_error_response():
class MockSession(object):
    """Fake aiohttp session exposing only the auto_decompress flag.

    The original stored the value under the same name as the property:
    ``__init__``'s ``self.auto_decompress = True`` hit the setter-less
    property (AttributeError), and the getter returned
    ``self.auto_decompress``, recursing forever. Use a private backing
    field instead.
    """
    def __init__(self):
        # Backing field read by the read-only property below.
        self._auto_decompress = True

    @property
    def auto_decompress(self):
        return self._auto_decompress

class MockTransport(AsyncHttpTransport):
def __init__(self):
self._count = 0
self.session = MockSession

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
Expand Down Expand Up @@ -60,7 +69,7 @@ async def __call__(self, *args, **kwargs):
pipeline = AsyncPipeline(MockTransport())
http_response = AsyncHttpResponse(http_request, None)
http_response.internal_response = MockInternalResponse()
stream = AioHttpStreamDownloadGenerator(pipeline, http_response)
stream = AioHttpStreamDownloadGenerator(pipeline, http_response, decompress=False)
with mock.patch('asyncio.sleep', new_callable=AsyncMock):
with pytest.raises(ConnectionError):
await stream.__anext__()
Expand All @@ -75,6 +84,8 @@ async def test_response_streaming_error_behavior():

class FakeStreamWithConnectionError:
# fake object for urllib3.response.HTTPResponse
def __init__(self):
    # Fake urllib3 body state: pretend 500 bytes remain to be served.
    self.total_response_size = 500

def stream(self, chunk_size, decode_content=False):
assert chunk_size == block_size
Expand All @@ -86,6 +97,15 @@ def stream(self, chunk_size, decode_content=False):
left -= len(data)
yield data

def read(self, chunk_size, decode_content=False):
    """Serve fake body bytes, raising ConnectionError on the final chunk."""
    assert chunk_size == block_size
    remaining = self.total_response_size
    if remaining <= 0:
        # Exhausted: mimic a stream with nothing left to return.
        return None
    if remaining <= block_size:
        # Fail exactly when the last chunk would be read.
        raise requests.exceptions.ConnectionError()
    data = b"X" * min(chunk_size, remaining)
    self.total_response_size -= len(data)
    return data

def close(self):
    # Nothing to release; present so callers can close() the fake stream.
    pass

Expand Down
Loading

0 comments on commit 68d502e

Please sign in to comment.