Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed wait_socket_readable/writable to accept a file descriptor #824

Merged
merged 18 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)
- Allowed ``wait_socket_readable`` and ``wait_socket_writable`` to accept a socket
file descriptor
agronholm marked this conversation as resolved.
Show resolved Hide resolved

**4.6.2**

Expand Down
20 changes: 15 additions & 5 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Optional,
TypeVar,
Expand Down Expand Up @@ -99,6 +100,9 @@
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand Down Expand Up @@ -1718,8 +1722,8 @@ async def send(self, item: bytes) -> None:
return


_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")


#
Expand Down Expand Up @@ -2671,14 +2675,17 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: socket.socket) -> None:
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
read_events = {}
_read_events.set(read_events)

if not isinstance(sock, int):
sock = sock.fileno()

if read_events.get(sock):
raise BusyResourceError("reading from") from None

Expand All @@ -2698,20 +2705,23 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None:
raise ClosedResourceError

@classmethod
async def wait_socket_writable(cls, sock: socket.socket) -> None:
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
write_events = {}
_write_events.set(write_events)

if not isinstance(sock, int):
sock = sock.fileno()

if write_events.get(sock):
raise BusyResourceError("writing to") from None

loop = get_running_loop()
event = write_events[sock] = asyncio.Event()
agronholm marked this conversation as resolved.
Show resolved Hide resolved
loop.add_writer(sock.fileno(), event.set)
loop.add_writer(sock, event.set)
try:
await event.wait()
finally:
Expand Down
8 changes: 6 additions & 2 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Generic,
NoReturn,
Expand Down Expand Up @@ -80,6 +81,9 @@
from ..abc._eventloop import AsyncBackend, StrOrBytesPath
from ..streams.memory import MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand Down Expand Up @@ -1260,7 +1264,7 @@ async def getnameinfo(
return await trio.socket.getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: socket.socket) -> None:
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
try:
await wait_readable(sock)
except trio.ClosedResourceError as exc:
Expand All @@ -1269,7 +1273,7 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None:
raise BusyResourceError("reading from") from None

@classmethod
async def wait_socket_writable(cls, sock: socket.socket) -> None:
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
try:
await wait_writable(sock)
except trio.ClosedResourceError as exc:
Expand Down
15 changes: 10 additions & 5 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ipaddress import IPv6Address, ip_address
from os import PathLike, chmod
from socket import AddressFamily, SocketKind
from typing import Any, Literal, cast, overload
from typing import TYPE_CHECKING, Any, Literal, cast, overload

from .. import to_thread
from ..abc import (
Expand All @@ -31,6 +31,11 @@
from ._synchronization import Event
from ._tasks import create_task_group, move_on_after

if TYPE_CHECKING:
from _typeshed import HasFileno
else:
HasFileno = object

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

Expand Down Expand Up @@ -591,7 +596,7 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str
return get_async_backend().getnameinfo(sockaddr, flags)


def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given socket has data to be read.

Expand All @@ -601,7 +606,7 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

:param sock: a socket object
:param sock: a socket object or its file descriptor
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become readable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
Expand All @@ -611,7 +616,7 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
return get_async_backend().wait_socket_readable(sock)


def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given socket can be written to.

Expand All @@ -621,7 +626,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

:param sock: a socket object
:param sock: a socket object or its file descriptor
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
Expand Down
6 changes: 4 additions & 2 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from _typeshed import HasFileno

from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
from .._core._tasks import CancelScope
from .._core._testing import TaskInfo
Expand Down Expand Up @@ -333,12 +335,12 @@ async def getnameinfo(

@classmethod
@abstractmethod
async def wait_socket_readable(cls, sock: socket) -> None:
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
pass

@classmethod
@abstractmethod
async def wait_socket_writable(cls, sock: socket) -> None:
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
pass

@classmethod
Expand Down
45 changes: 43 additions & 2 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from socket import AddressFamily
from ssl import SSLContext, SSLError
from threading import Thread
from typing import Any, NoReturn, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, NoReturn, TypeVar, cast

import psutil
import pytest
Expand Down Expand Up @@ -46,6 +46,8 @@
getnameinfo,
move_on_after,
wait_all_tasks_blocked,
wait_socket_readable,
wait_socket_writable,
)
from anyio.abc import (
IPSockAddrType,
Expand All @@ -60,7 +62,8 @@
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

from typing import Literal
if TYPE_CHECKING:
from _typeshed import HasFileno

AnyIPAddressFamily = Literal[
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
Expand Down Expand Up @@ -1849,3 +1852,41 @@ async def test_connect_tcp_getaddrinfo_context() -> None:
pass

assert exc_info.value.__context__ is None


@pytest.mark.parametrize("socket_type", ["socket", "fd"])
@pytest.mark.parametrize("event", ["readable", "writable"])
async def test_wait_socket(
anyio_backend_name: str, event: str, socket_type: str
) -> None:
if anyio_backend_name == "asyncio" and sys.platform == "win32":
import asyncio

policy = asyncio.get_event_loop_policy()
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")

wait_socket = wait_socket_readable if event == "readable" else wait_socket_writable

def client(port: int) -> None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect(("127.0.0.1", port))
sock.sendall(b"Hello, world")

with move_on_after(0.1):
agronholm marked this conversation as resolved.
Show resolved Hide resolved
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(("127.0.0.1", 0))
port = sock.getsockname()[1]
sock.listen()
thread = Thread(target=client, args=(port,), daemon=True)
thread.start()
conn, addr = sock.accept()
with conn:
sock_or_fd: HasFileno | int = (
conn.fileno() if socket_type == "fd" else conn
)
await wait_socket(sock_or_fd)
socket_ready = True

assert socket_ready
thread.join()
agronholm marked this conversation as resolved.
Show resolved Hide resolved