Skip to content

Commit

Permalink
[PR #9454/b20908e backport][3.10] Simplify DNS throttle implementation (
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Oct 10, 2024
1 parent ee87a04 commit cdf3dca
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 141 deletions.
1 change: 1 addition & 0 deletions CHANGES/9454.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplified DNS resolution throttling code to reduce the chance of race conditions -- by :user:`bdraco`.
96 changes: 59 additions & 37 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import suppress
from http import HTTPStatus
from http.cookies import SimpleCookie
from itertools import cycle, islice
from itertools import chain, cycle, islice
from time import monotonic
from types import TracebackType
from typing import (
Expand Down Expand Up @@ -50,8 +50,14 @@
)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
from .helpers import ceil_timeout, is_ip_address, noop, sentinel
from .locks import EventResultOrError
from .helpers import (
ceil_timeout,
is_ip_address,
noop,
sentinel,
set_exception,
set_result,
)
from .resolver import DefaultResolver

try:
Expand Down Expand Up @@ -840,7 +846,9 @@ def __init__(

self._use_dns_cache = use_dns_cache
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {}
self._throttle_dns_futures: Dict[
Tuple[str, int], Set["asyncio.Future[None]"]
] = {}
self._family = family
self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
self._happy_eyeballs_delay = happy_eyeballs_delay
Expand All @@ -849,8 +857,8 @@ def __init__(

def close(self) -> Awaitable[None]:
"""Close all ongoing DNS calls."""
for ev in self._throttle_dns_events.values():
ev.cancel()
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
fut.cancel()

for t in self._resolve_host_tasks:
t.cancel()
Expand Down Expand Up @@ -918,18 +926,35 @@ async def _resolve_host(
await trace.send_dns_cache_hit(host)
return result

futures: Set["asyncio.Future[None]"]
#
# If multiple connectors are resolving the same host, we wait
# for the first one to resolve and then use the result for all of them.
# We use a throttle event to ensure that we only resolve the host once
# We use a throttle to ensure that we only resolve the host once
# and then use the result for all the waiters.
#
if key in self._throttle_dns_futures:
# get futures early, before any await (#4014)
futures = self._throttle_dns_futures[key]
future: asyncio.Future[None] = self._loop.create_future()
futures.add(future)
if traces:
for trace in traces:
await trace.send_dns_cache_hit(host)
try:
await future
finally:
futures.discard(future)
return self._cached_hosts.next_addrs(key)

# update dict early, before any await (#4014)
self._throttle_dns_futures[key] = futures = set()
# In this case we need to create a task to ensure that we can shield
# the task from cancellation as cancelling this lookup should not cancel
# the underlying lookup or else the cancel event will get broadcast to
# all the waiters across all connections.
#
coro = self._resolve_host_with_throttle(key, host, port, traces)
coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
loop = asyncio.get_running_loop()
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send immediately
Expand Down Expand Up @@ -957,42 +982,39 @@ async def _resolve_host_with_throttle(
key: Tuple[str, int],
host: str,
port: int,
futures: Set["asyncio.Future[None]"],
traces: Optional[Sequence["Trace"]],
) -> List[ResolveResult]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
"""Resolve host and set result for all waiters.
This method must be run in a task and shielded from cancellation
to avoid cancelling the underlying lookup.
"""
if traces:
for trace in traces:
await trace.send_dns_cache_miss(host)
try:
if traces:
for trace in traces:
await trace.send_dns_cache_hit(host)
await event.wait()
else:
# update dict early, before any await (#4014)
self._throttle_dns_events[key] = EventResultOrError(self._loop)
await trace.send_dns_resolvehost_start(host)

addrs = await self._resolver.resolve(host, port, family=self._family)
if traces:
for trace in traces:
await trace.send_dns_cache_miss(host)
try:

if traces:
for trace in traces:
await trace.send_dns_resolvehost_start(host)

addrs = await self._resolver.resolve(host, port, family=self._family)
if traces:
for trace in traces:
await trace.send_dns_resolvehost_end(host)
await trace.send_dns_resolvehost_end(host)

self._cached_hosts.add(key, addrs)
self._throttle_dns_events[key].set()
except BaseException as e:
# any DNS exception, independently of the implementation
# is set for the waiters to raise the same exception.
self._throttle_dns_events[key].set(exc=e)
raise
finally:
self._throttle_dns_events.pop(key)
self._cached_hosts.add(key, addrs)
for fut in futures:
set_result(fut, None)
except BaseException as e:
# any DNS exception is set for the waiters to raise the same exception.
# This coro is always run in a task that is shielded from cancellation so
# we should never be propagating cancellation here.
for fut in futures:
set_exception(fut, e)
raise
finally:
self._throttle_dns_futures.pop(key)

return self._cached_hosts.next_addrs(key)

Expand Down
41 changes: 0 additions & 41 deletions aiohttp/locks.py

This file was deleted.

Loading

0 comments on commit cdf3dca

Please sign in to comment.