Skip to content

Commit

Permalink
Core decompress body (#18581)
Browse files Browse the repository at this point in the history
* decompress body

* update

* update

* update

* update

* update recorded tests

* update

* update

* update

* update

* update

* add type annotation

* update

* update

* update

* update

* update

* update

* update doc

* update

* update

* add comments

* update
  • Loading branch information
xiangyan99 authored May 13, 2021
1 parent 802c887 commit 502c702
Show file tree
Hide file tree
Showing 701 changed files with 56,247 additions and 111,641 deletions.
82 changes: 69 additions & 13 deletions sdk/core/azure-core/azure/core/pipeline/transport/_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
# --------------------------------------------------------------------------
from typing import Any, Optional, AsyncIterator as AsyncIteratorType
from collections.abc import AsyncIterator
try:
import cchardet as chardet
except ImportError: # pragma: no cover
import chardet # type: ignore

import logging
import asyncio
import codecs
import aiohttp
from multidict import CIMultiDict
from requests.exceptions import StreamConsumedError
Expand Down Expand Up @@ -66,7 +71,7 @@ class AioHttpTransport(AsyncHttpTransport):
:dedent: 4
:caption: Asynchronous transport with aiohttp.
"""
def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner=True, **kwargs):
self._loop = loop
self._session_owner = session_owner
self.session = session
Expand Down Expand Up @@ -145,6 +150,11 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
:keyword str proxy: will define the proxy to use all the time
"""
await self.open()
try:
auto_decompress = self.session.auto_decompress # type: ignore
except AttributeError:
# auto_decompress is introduced in Python 3.7. We need this to handle Python 3.6.
auto_decompress = True

proxies = config.pop('proxies', None)
if proxies and 'proxy' not in config:
Expand All @@ -171,7 +181,7 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
timeout = config.pop('connection_timeout', self.connection_config.timeout)
read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
result = await self.session.request(
result = await self.session.request( # type: ignore
request.method,
request.url,
headers=request.headers,
Expand All @@ -180,7 +190,9 @@ async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpR
allow_redirects=False,
**config
)
response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
response = AioHttpTransportResponse(request, result,
self.connection_config.data_block_size,
decompress=not auto_decompress)
if not stream_response:
await response.load_body()
except aiohttp.client_exceptions.ClientResponseError as err:
Expand All @@ -196,17 +208,15 @@ class AioHttpStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The client response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
:param bool decompress: If True which is default, will attempt to decode the body based
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, *, decompress=True) -> None:
self.pipeline = pipeline
self.request = response.request
self.response = response
self.block_size = response.block_size
self._decompress = kwargs.pop("decompress", True)
if len(kwargs) > 0:
raise TypeError("Got an unexpected keyword argument: {}".format(list(kwargs.keys())[0]))
self._decompress = decompress
self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
self._decompressor = None

Expand Down Expand Up @@ -250,21 +260,41 @@ class AioHttpTransportResponse(AsyncHttpResponse):
:type aiohttp_response: aiohttp.ClientResponse object
:param block_size: block size of data sent over connection.
:type block_size: int
:param bool decompress: If True which is default, will attempt to decode the body based
on the *content-encoding* header.
"""
def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
def __init__(self, request: HttpRequest,
aiohttp_response: aiohttp.ClientResponse,
block_size=None, *, decompress=True) -> None:
super(AioHttpTransportResponse, self).__init__(request, aiohttp_response, block_size=block_size)
# https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
self.status_code = aiohttp_response.status
self.headers = CIMultiDict(aiohttp_response.headers)
self.reason = aiohttp_response.reason
self.content_type = aiohttp_response.headers.get('content-type')
self._body = None
self._decompressed_body = None
self._decompress = decompress

def body(self) -> bytes:
"""Return the whole body as bytes in memory.
"""
if self._body is None:
raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
if not self._decompress:
return self._body
enc = self.headers.get('Content-Encoding')
if not enc:
return self._body
enc = enc.lower()
if enc in ("gzip", "deflate"):
if self._decompressed_body:
return self._decompressed_body
import zlib
zlib_mode = 16 + zlib.MAX_WBITS if enc == "gzip" else zlib.MAX_WBITS
decompressor = zlib.decompressobj(wbits=zlib_mode)
self._decompressed_body = decompressor.decompress(self._body)
return self._decompressed_body
return self._body

def text(self, encoding: Optional[str] = None) -> str:
Expand All @@ -274,10 +304,36 @@ def text(self, encoding: Optional[str] = None) -> str:
:param str encoding: The encoding to apply.
"""
# super().text detects charset based on self._body() which is compressed
# implement the decoding explicitly here
body = self.body()

ctype = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
mimetype = aiohttp.helpers.parse_mimetype(ctype)

encoding = mimetype.parameters.get("charset")
if encoding:
try:
codecs.lookup(encoding)
except LookupError:
encoding = None
if not encoding:
if mimetype.type == "application" and (
mimetype.subtype == "json" or mimetype.subtype == "rdap"
):
# RFC 7159 states that the default encoding is UTF-8.
# RFC 7483 defines application/rdap+json
encoding = "utf-8"
elif body is None:
raise RuntimeError(
"Cannot guess the encoding of a not yet read body"
)
else:
encoding = chardet.detect(body)["encoding"]
if not encoding:
encoding = self.internal_response.get_encoding()
encoding = "utf-8-sig"

return super().text(encoding)
return body.decode(encoding)

async def load_body(self) -> None:
"""Load in memory the body, so it could be accessible from sync methods."""
Expand All @@ -289,7 +345,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
:param pipeline: The pipeline object
:type pipeline: azure.core.pipeline.Pipeline
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def stream_download(self, pipeline, **kwargs) -> AsyncIteratorType[bytes]:
:param pipeline: The pipeline object
:type pipeline: azure.core.pipeline.Pipeline
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""

def parts(self) -> AsyncIterator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class AsyncioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
self.pipeline = pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class StreamDownloadGenerator(object):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline, response, **kwargs):
self.pipeline = pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TrioStreamDownloadGenerator(AsyncIterator):
:param pipeline: The pipeline object
:param response: The response object.
:keyword bool decompress: If True which is default, will attempt to decode the body based
on the content-encoding header.
on the *content-encoding* header.
"""
def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse, **kwargs) -> None:
self.pipeline = pipeline
Expand Down
176 changes: 176 additions & 0 deletions sdk/core/azure-core/tests/async_tests/test_streaming_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# --------------------------------------------------------------------------
import os
import pytest
from azure.core import AsyncPipelineClient

@pytest.mark.asyncio
async def test_decompress_plain_no_header():
# expect plain text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=True)
content = b""
async for d in data:
content += d
decoded = content.decode('utf-8')
assert decoded == "test"

@pytest.mark.asyncio
async def test_compress_plain_no_header():
# expect plain text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=False)
content = b""
async for d in data:
content += d
decoded = content.decode('utf-8')
assert decoded == "test"

@pytest.mark.asyncio
async def test_decompress_compressed_no_header():
# expect compressed text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=True)
content = b""
async for d in data:
content += d
try:
decoded = content.decode('utf-8')
assert False
except UnicodeDecodeError:
pass

@pytest.mark.asyncio
async def test_compress_compressed_no_header():
# expect compressed text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=False)
content = b""
async for d in data:
content += d
try:
decoded = content.decode('utf-8')
assert False
except UnicodeDecodeError:
pass

@pytest.mark.asyncio
async def test_decompress_plain_header():
# expect error
import zlib
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=True)
try:
content = b""
async for d in data:
content += d
assert False
except zlib.error:
pass

@pytest.mark.asyncio
async def test_compress_plain_header():
# expect plain text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=False)
content = b""
async for d in data:
content += d
decoded = content.decode('utf-8')
assert decoded == "test"

@pytest.mark.asyncio
async def test_decompress_compressed_header():
# expect plain text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=True)
content = b""
async for d in data:
content += d
decoded = content.decode('utf-8')
assert decoded == "test"

@pytest.mark.asyncio
async def test_compress_compressed_header():
# expect compressed text
account_name = "coretests"
account_url = "https://{}.blob.core.windows.net".format(account_name)
url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name)
client = AsyncPipelineClient(account_url)
request = client.get(url)
pipeline_response = await client._pipeline.run(request, stream=True)
response = pipeline_response.http_response
data = response.stream_download(client._pipeline, decompress=False)
content = b""
async for d in data:
content += d
try:
decoded = content.decode('utf-8')
assert False
except UnicodeDecodeError:
pass
Loading

0 comments on commit 502c702

Please sign in to comment.