Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aiohttp fixed bug with HEAD requests and connection reuse #55

Merged
merged 1 commit into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions elastic_transport/_node/_http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .._compat import get_running_loop, warn_stacklevel
from .._exceptions import ConnectionError, ConnectionTimeout, SecurityWarning, TlsError
from .._models import ApiResponseMeta, HttpHeaders, NodeConfig
from ..client_utils import DEFAULT, DefaultType, client_meta_version, resolve_default
from ..client_utils import DEFAULT, DefaultType, client_meta_version
from ._base import (
BUILTIN_EXCEPTIONS,
DEFAULT_CA_CERTS,
Expand All @@ -42,9 +42,14 @@

_AIOHTTP_AVAILABLE = True
_AIOHTTP_META_VERSION = client_meta_version(aiohttp.__version__)
_AIOHTTP_SEMVER_VERSION = tuple(int(x) for x in aiohttp.__version__.split(".")[:3])

# See aio-libs/aiohttp#1769 and #5012
_AIOHTTP_FIXED_HEAD_BUG = _AIOHTTP_SEMVER_VERSION >= (3, 7, 0)
except ImportError: # pragma: nocover
_AIOHTTP_AVAILABLE = False
_AIOHTTP_META_VERSION = ""
_AIOHTTP_FIXED_HEAD_BUG = False


class AiohttpHttpNode(BaseAsyncNode):
Expand Down Expand Up @@ -117,22 +122,27 @@ async def perform_request( # type: ignore[override]
headers: Optional[HttpHeaders] = None,
request_timeout: Union[DefaultType, Optional[float]] = DEFAULT,
) -> Tuple[ApiResponseMeta, bytes]:
global _AIOHTTP_FIXED_HEAD_BUG
if self.session is None:
self._create_aiohttp_session()
assert self.session is not None

url = self.base_url + target

# There is a bug in aiohttp that disables the re-use
is_head = False
# There is a bug in aiohttp<3.7 that disables the re-use
# of the connection in the pool when method=HEAD.
# See: aio-libs/aiohttp#1769
is_head = False
if method == "HEAD":
if method == "HEAD" and not _AIOHTTP_FIXED_HEAD_BUG:
method = "GET"
is_head = True

# total=0 means no timeout for aiohttp
resolved_timeout = resolve_default(request_timeout, self.config.request_timeout)
resolved_timeout: Optional[float] = (
self.config.request_timeout
if request_timeout is DEFAULT
else request_timeout
)
aiohttp_timeout = aiohttp.ClientTimeout(
total=resolved_timeout if resolved_timeout is not None else 0
)
Expand Down
32 changes: 31 additions & 1 deletion tests/node/test_http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ async def __aexit__(self, *_, **__):
pass

async def read(self):
return response_body
return response_body if args[0] != "HEAD" else b""

async def release(self):
return None

dummy_response = DummyResponse()
dummy_response.headers = CIMultiDict()
Expand Down Expand Up @@ -250,6 +253,33 @@ async def test_merge_headers(self):
"user-agent": DEFAULT_USER_AGENT,
}

@pytest.mark.parametrize("aiohttp_fixed_head_bug", [True, False])
async def test_head_workaround(self, aiohttp_fixed_head_bug):
from elastic_transport._node import _http_aiohttp

prev = _http_aiohttp._AIOHTTP_FIXED_HEAD_BUG
try:
_http_aiohttp._AIOHTTP_FIXED_HEAD_BUG = aiohttp_fixed_head_bug

node = await self._get_mock_node(
NodeConfig(
scheme="https",
host="localhost",
port=443,
)
)
resp, data = await node.perform_request("HEAD", "/anything")

method, url = node.session.request.call_args[0]
assert method == "HEAD" if aiohttp_fixed_head_bug else "GET"
assert url == "https://localhost:443/anything"

assert resp.status == 200
assert data == b""

finally:
_http_aiohttp._AIOHTTP_FIXED_HEAD_BUG = prev


async def test_ssl_assert_fingerprint(httpbin_cert_fingerprint):
with warnings.catch_warnings(record=True) as w:
Expand Down