Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid creating a task to do DNS resolution if there is no throttle #8163

Merged
merged 16 commits into from
Feb 20, 2024
Merged
5 changes: 5 additions & 0 deletions CHANGES/8163.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Improved the DNS resolution performance on cache hit
bdraco marked this conversation as resolved.
Show resolved Hide resolved
-- by :user:`bdraco`.

This is achieved by avoiding an :mod:`asyncio` task creation
bdraco marked this conversation as resolved.
Show resolved Hide resolved
in this case.
50 changes: 36 additions & 14 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ def clear_dns_cache(
async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
{
Expand All @@ -840,8 +841,7 @@ async def _resolve_host(
return res

key = (host, port)

if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
if key in self._cached_hosts and not self._cached_hosts.expired(key):
# get result early, before any await (#4014)
result = self._cached_hosts.next_addrs(key)

Expand All @@ -850,6 +850,39 @@ async def _resolve_host(
await trace.send_dns_cache_hit(host)
return result

#
# If multiple connectors are resolving the same host, we wait
# for the first one to resolve and then use the result for all of them.
# We use a throttle event to ensure that we only resolve the host once
# and then use the result for all the waiters.
#
# In this case we need to create a task to ensure that we can shield
# the task from cancellation as cancelling this lookup should not cancel
# the underlying lookup or else the cancel event will get broadcast to
# all the waiters across all connections.
#
resolved_host_task = asyncio.create_task(
self._resolve_host_with_throttle(key, host, port, traces)
)
try:
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
bdraco marked this conversation as resolved.
Show resolved Hide resolved
with suppress(Exception, asyncio.CancelledError):
fut.result()

resolved_host_task.add_done_callback(drop_exception)
raise

async def _resolve_host_with_throttle(
self,
key: Tuple[str, int],
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
Expand Down Expand Up @@ -1136,22 +1169,11 @@ async def _create_direct_connection(
host = host.rstrip(".") + "."
port = req.port
assert port is not None
host_resolved = asyncio.ensure_future(
self._resolve_host(host, port, traces=traces), loop=self._loop
)
try:
# Cancelling this lookup should not cancel the underlying lookup
# or else the cancel event will get broadcast to all the waiters
# across all connections.
hosts = await asyncio.shield(host_resolved)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

host_resolved.add_done_callback(drop_exception)
raise
hosts = await self._resolve_host(host, port, traces=traces)
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
raise
Expand Down
6 changes: 6 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,7 @@ async def test_tcp_connector_dns_throttle_requests(
loop.create_task(conn._resolve_host("localhost", 8080))
loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)


Expand All @@ -1032,6 +1033,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop: Any) -
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
bdraco marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(0)
await asyncio.sleep(0)
assert r1.exception() == e
assert r2.exception() == e

Expand All @@ -1045,6 +1049,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
loop.create_task(conn._resolve_host("localhost", 8080))
f = loop.create_task(conn._resolve_host("localhost", 8080))

await asyncio.sleep(0)
await asyncio.sleep(0)
await conn.close()

Expand Down Expand Up @@ -1212,6 +1217,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
await asyncio.sleep(0)
await asyncio.sleep(0)
on_dns_cache_hit.assert_called_once_with(
session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost")
)
Expand Down
Loading