Skip to content

Commit

Permalink
Renamed functions, deprecated old ones
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 16, 2024
1 parent cd212aa commit e60655c
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 24 deletions.
7 changes: 4 additions & 3 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)
- Allowed ``wait_socket_readable`` and ``wait_socket_writable`` to accept a socket
file descriptor (`#824 <https://github.com/agronholm/anyio/pull/824>`_)
(PR by @davidbrochart)
- Added ``wait_readable`` and ``wait_writable`` functions that accept an object with a
``.fileno()`` method or an integer handle, and deprecated ``wait_socket_readable``
and ``wait_socket_writable``.
(`#824 <https://github.com/agronholm/anyio/pull/824>`_) (PR by @davidbrochart)

**4.6.2**

Expand Down
2 changes: 2 additions & 0 deletions src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
from ._core._sockets import wait_writable as wait_writable
from ._core._streams import create_memory_object_stream as create_memory_object_stream
from ._core._subprocesses import open_process as open_process
from ._core._subprocesses import run_process as run_process
Expand Down
76 changes: 65 additions & 11 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,8 +1722,8 @@ async def send(self, item: bytes) -> None:
return


_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
_read_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("write_events")


#
Expand Down Expand Up @@ -2675,17 +2675,14 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
async def wait_socket_readable(cls, sock: socket.socket) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
read_events = {}
_read_events.set(read_events)

if not isinstance(sock, int):
sock = sock.fileno()

if read_events.get(sock):
raise BusyResourceError("reading from") from None

Expand All @@ -2705,23 +2702,20 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
raise ClosedResourceError

@classmethod
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
async def wait_socket_writable(cls, sock: socket.socket) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
write_events = {}
_write_events.set(write_events)

if not isinstance(sock, int):
sock = sock.fileno()

if write_events.get(sock):
raise BusyResourceError("writing to") from None

loop = get_running_loop()
event = write_events[sock] = asyncio.Event()
loop.add_writer(sock, event.set)
loop.add_writer(sock.fileno(), event.set)
try:
await event.wait()
finally:
Expand All @@ -2734,6 +2728,66 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
if not writable:
raise ClosedResourceError

@classmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
read_events = {}
_read_events.set(read_events)

if not isinstance(obj, int):
obj = obj.fileno()

if read_events.get(obj):
raise BusyResourceError("reading from") from None

loop = get_running_loop()
event = read_events[obj] = asyncio.Event()
loop.add_reader(obj, event.set)
try:
await event.wait()
finally:
if read_events.pop(obj, None) is not None:
loop.remove_reader(obj)
readable = True
else:
readable = False

if not readable:
raise ClosedResourceError

@classmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
write_events = {}
_write_events.set(write_events)

if not isinstance(obj, int):
obj = obj.fileno()

if write_events.get(obj):
raise BusyResourceError("writing to") from None

loop = get_running_loop()
event = write_events[obj] = asyncio.Event()
loop.add_writer(obj, event.set)
try:
await event.wait()
finally:
if write_events.pop(obj, None) is not None:
loop.remove_writer(obj)
writable = True
else:
writable = False

if not writable:
raise ClosedResourceError

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
Expand Down
22 changes: 20 additions & 2 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,7 +1264,7 @@ async def getnameinfo(
return await trio.socket.getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
async def wait_socket_readable(cls, sock: socket.socket) -> None:
try:
await wait_readable(sock)
except trio.ClosedResourceError as exc:
Expand All @@ -1273,14 +1273,32 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
raise BusyResourceError("reading from") from None

@classmethod
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
async def wait_socket_writable(cls, sock: socket.socket) -> None:
try:
await wait_writable(sock)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
raise BusyResourceError("writing to") from None

@classmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
try:
await wait_readable(obj)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
raise BusyResourceError("reading from") from None

@classmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
try:
await wait_writable(obj)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
raise BusyResourceError("writing to") from None

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
Expand Down
75 changes: 71 additions & 4 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from os import PathLike, chmod
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Literal, cast, overload
from warnings import warn

from .. import to_thread
from ..abc import (
Expand Down Expand Up @@ -596,8 +597,10 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str
return get_async_backend().getnameinfo(sockaddr, flags)


def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]:
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
"""
Deprecated, use `wait_readable` instead.
Wait until the given socket has data to be read.
This does **NOT** work on Windows when using the asyncio backend with a proactor
Expand All @@ -606,18 +609,25 @@ def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
:param sock: a socket object or its file descriptor
:param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become readable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
to become readable
"""
warn(
"This function is deprecated; use `wait_readable` instead",
DeprecationWarning,
stacklevel=2,
)
return get_async_backend().wait_socket_readable(sock)


def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]:
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
"""
Deprecated, use `wait_writable` instead.
Wait until the given socket can be written to.
This does **NOT** work on Windows when using the asyncio backend with a proactor
Expand All @@ -626,16 +636,73 @@ def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
:param sock: a socket object or its file descriptor
:param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
to become writable
"""
warn(
"This function is deprecated; use `wait_writable` instead",
DeprecationWarning,
stacklevel=2,
)
return get_async_backend().wait_socket_writable(sock)


def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given object has data to be read.
On Unix systems, ``obj`` must either be an integer file descriptor, or else an
object with a ``.fileno()`` method which returns an integer file descriptor. Any
kind of file descriptor can be passed, though the exact semantics will depend on
your kernel. For example, this probably won't do anything useful for on-disk files.
On Windows systems, ``obj`` must either be an integer ``SOCKET`` handle, or else an
object with a ``.fileno()`` method which returns an integer ``SOCKET`` handle. File
descriptors aren't supported, and neither are handles that refer to anything besides
a ``SOCKET``.
This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
:param obj: an object with a ``.fileno()`` method or an integer handle.
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become readable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become readable
"""
return get_async_backend().wait_readable(obj)


def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given object can be written to.
See `wait_readable` for the definition of ``obj``.
This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
:param obj: an object with a ``.fileno()`` method or an integer handle.
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become writable
"""
return get_async_backend().wait_writable(obj)


#
# Private API
#
Expand Down
14 changes: 12 additions & 2 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,22 @@ async def getnameinfo(

@classmethod
@abstractmethod
async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
async def wait_socket_readable(cls, sock: socket) -> None:
pass

@classmethod
@abstractmethod
async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
async def wait_socket_writable(cls, sock: socket) -> None:
pass

@classmethod
@abstractmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
pass

@classmethod
@abstractmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
pass

@classmethod
Expand Down
22 changes: 20 additions & 2 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
getnameinfo,
move_on_after,
wait_all_tasks_blocked,
wait_readable,
wait_socket_readable,
wait_socket_writable,
wait_writable,
)
from anyio.abc import (
IPSockAddrType,
Expand Down Expand Up @@ -1866,7 +1868,7 @@ async def test_wait_socket(
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")

wait_socket = wait_socket_readable if event == "readable" else wait_socket_writable
wait = wait_readable if event == "readable" else wait_writable

def client(port: int) -> None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
Expand All @@ -1884,4 +1886,20 @@ def client(port: int) -> None:
with conn:
sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
with fail_after(10):
await wait_socket(sock_or_fd)
await wait(sock_or_fd)


async def test_deprecated_wait_socket() -> None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
with pytest.warns(
DeprecationWarning,
match="This function is deprecated; use `wait_readable` instead",
):
with move_on_after(0.1):
await wait_socket_readable(sock)
with pytest.warns(
DeprecationWarning,
match="This function is deprecated; use `wait_writable` instead",
):
with move_on_after(0.1):
await wait_socket_writable(sock)

0 comments on commit e60655c

Please sign in to comment.