From 5f0b954fc36545cb752001cb571397a4955d4270 Mon Sep 17 00:00:00 2001 From: Min RK Date: Mon, 23 Sep 2024 09:00:24 +0200 Subject: [PATCH] more annotations for zmq.asyncio.Socket missing for rarely used pyobj, json methods done via type stub instead of in implementation --- mypy_tests/test_socket_async.py | 53 +++++++++++++++++++ zmq/_future.py | 53 +++---------------- zmq/_future.pyi | 92 +++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 45 deletions(-) create mode 100644 mypy_tests/test_socket_async.py create mode 100644 zmq/_future.pyi diff --git a/mypy_tests/test_socket_async.py b/mypy_tests/test_socket_async.py new file mode 100644 index 000000000..60192deda --- /dev/null +++ b/mypy_tests/test_socket_async.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import asyncio + +import zmq +import zmq.asyncio + + +async def main() -> None: + ctx = zmq.asyncio.Context() + + # shadow exercise + sync_ctx: zmq.Context = zmq.Context.shadow(ctx) + ctx2: zmq.asyncio.Context = zmq.asyncio.Context.shadow(sync_ctx) + ctx2 = zmq.asyncio.Context(sync_ctx) + + url = "tcp://127.0.0.1:5555" + pub = ctx.socket(zmq.PUB) + sub = ctx.socket(zmq.SUB) + pub.bind(url) + sub.connect(url) + sub.subscribe(b"") + await asyncio.sleep(1) + + # shadow exercise + sync_sock: zmq.Socket[bytes] = zmq.Socket.shadow(pub) + s2: zmq.asyncio.Socket = zmq.asyncio.Socket(sync_sock) + s2 = zmq.asyncio.Socket.from_socket(sync_sock) + + print("sending") + await pub.send(b"plain") + await pub.send(b"plain") + await pub.send_multipart([b"topic", b"Message"]) + await pub.send_multipart([b"topic", b"Message"]) + await pub.send_string("asdf") + await pub.send_pyobj(123) + await pub.send_json({"a": "5"}) + + print("receiving") + msg_bytes: bytes = await sub.recv() + msg_frame: zmq.Frame = await sub.recv(copy=False) + msg_list: list[bytes] = await sub.recv_multipart() + msg_frames: list[zmq.Frame] = await sub.recv_multipart(copy=False) + s: str = await sub.recv_string() + obj = await sub.recv_pyobj() + d = await sub.recv_json() + + pub.close() + sub.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/zmq/_future.py b/zmq/_future.py index effd0d564..388284e74 100644 --- a/zmq/_future.py +++ b/zmq/_future.py @@ -9,11 +9,17 @@ from collections import deque from functools import partial from itertools import chain -from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload +from typing import ( + Any, + Awaitable, + Callable, + NamedTuple, + TypeVar, + cast, +) import zmq as _zmq from zmq import EVENTS, POLLIN, POLLOUT -from zmq._typing import Literal class _FutureEvent(NamedTuple): @@ -260,27 +266,6 @@ def get(self, key): get.__doc__ = _zmq.Socket.get.__doc__ - @overload # type: ignore - def recv_multipart( - self, flags: int = 0, *, track: bool = False - ) -> Awaitable[list[bytes]]: ... - - @overload - def recv_multipart( - self, flags: int = 0, *, copy: Literal[True], track: bool = False - ) -> Awaitable[list[bytes]]: ... - - @overload - def recv_multipart( - self, flags: int = 0, *, copy: Literal[False], track: bool = False - ) -> Awaitable[list[_zmq.Frame]]: # type: ignore - ... - - @overload - def recv_multipart( - self, flags: int = 0, copy: bool = True, track: bool = False - ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ... - def recv_multipart( self, flags: int = 0, copy: bool = True, track: bool = False ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: @@ -292,19 +277,6 @@ def recv_multipart( 'recv_multipart', dict(flags=flags, copy=copy, track=track) ) - @overload # type: ignore - def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ... - - @overload - def recv( - self, flags: int = 0, *, copy: Literal[True], track: bool = False - ) -> Awaitable[bytes]: ... - - @overload - def recv( - self, flags: int = 0, *, copy: Literal[False], track: bool = False - ) -> Awaitable[_zmq.Frame]: ... - def recv( # type: ignore self, flags: int = 0, copy: bool = True, track: bool = False ) -> Awaitable[bytes | _zmq.Frame]: @@ -440,15 +412,6 @@ def cancel_poll(future): return future - # overrides only necessary for updated types - def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore - return super().recv_string(*args, **kwargs) # type: ignore - - def send_string( # type: ignore - self, s: str, flags: int = 0, encoding: str = 'utf-8' - ) -> Awaitable[None]: - return super().send_string(s, flags=flags, encoding=encoding) # type: ignore - def _add_timeout(self, future, timeout): """Add a timeout for a send or recv Future""" diff --git a/zmq/_future.pyi b/zmq/_future.pyi new file mode 100644 index 000000000..1f193aac7 --- /dev/null +++ b/zmq/_future.pyi @@ -0,0 +1,92 @@ +"""type annotations for async sockets""" + +from __future__ import annotations + +from asyncio import Future +from pickle import DEFAULT_PROTOCOL +from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload + +import zmq as _zmq + +class _AsyncPoller(_zmq.Poller): + _socket_class: type[_AsyncSocket] + + def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore + +T = TypeVar("T", bound="_AsyncSocket") + +class _AsyncSocket(_zmq.Socket[Future]): + @classmethod + def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ... + def send( # type: ignore + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: int | None = None, + group: str | None = None, + ) -> Awaitable[_zmq.MessageTracker | None]: ... + @overload # type: ignore + def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ... + @overload + def recv( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[bytes]: ... + @overload + def recv( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[_zmq.Frame]: ... + @overload + def recv( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[bytes | _zmq.Frame]: ... + def send_multipart( # type: ignore + self, + msg_parts: Sequence, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: int | None = None, + group: str | None = None, + ) -> Awaitable[_zmq.MessageTracker | None]: ... + @overload # type: ignore + def recv_multipart( + self, flags: int = 0, *, track: bool = False + ) -> Awaitable[list[bytes]]: ... + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[list[bytes]]: ... + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[list[_zmq.Frame]]: ... + @overload + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ... + + # serialization wrappers + + def send_string( # type: ignore + self, + u: str, + flags: int = 0, + copy: bool = True, + *, + encoding: str = 'utf-8', + **kwargs, + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_string( # type: ignore + self, flags: int = 0, encoding: str = 'utf-8' + ) -> Awaitable[str]: ... + def send_pyobj( # type: ignore + self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore + def send_json( # type: ignore + self, obj: Any, flags: int = 0, **kwargs + ) -> Awaitable[_zmq.Frame | None]: ... + def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore + def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore