Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more annotations for zmq.asyncio.Socket #2035

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions mypy_tests/test_socket_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import asyncio

import zmq
import zmq.asyncio


async def main() -> None:
ctx = zmq.asyncio.Context()

# shadow exercise
sync_ctx: zmq.Context = zmq.Context.shadow(ctx)
ctx2: zmq.asyncio.Context = zmq.asyncio.Context.shadow(sync_ctx)
ctx2 = zmq.asyncio.Context(sync_ctx)

url = "tcp://127.0.0.1:5555"
pub = ctx.socket(zmq.PUB)
sub = ctx.socket(zmq.SUB)
pub.bind(url)
sub.connect(url)
sub.subscribe(b"")
await asyncio.sleep(1)

# shadow exercise
sync_sock: zmq.Socket[bytes] = zmq.Socket.shadow(pub)
s2: zmq.asyncio.Socket = zmq.asyncio.Socket(sync_sock)
s2 = zmq.asyncio.Socket.from_socket(sync_sock)

print("sending")
await pub.send(b"plain")
await pub.send(b"plain")
await pub.send_multipart([b"topic", b"Message"])
await pub.send_multipart([b"topic", b"Message"])
await pub.send_string("asdf")
await pub.send_pyobj(123)
await pub.send_json({"a": "5"})

print("receiving")
msg_bytes: bytes = await sub.recv()
msg_frame: zmq.Frame = await sub.recv(copy=False)
msg_list: list[bytes] = await sub.recv_multipart()
msg_frames: list[zmq.Frame] = await sub.recv_multipart(copy=False)
s: str = await sub.recv_string()
obj = await sub.recv_pyobj()
d = await sub.recv_json()

pub.close()
sub.close()


if __name__ == "__main__":
asyncio.run(main())
53 changes: 8 additions & 45 deletions zmq/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
from collections import deque
from functools import partial
from itertools import chain
from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload
from typing import (
Any,
Awaitable,
Callable,
NamedTuple,
TypeVar,
cast,
)

import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
from zmq._typing import Literal


class _FutureEvent(NamedTuple):
Expand Down Expand Up @@ -260,27 +266,6 @@ def get(self, key):

get.__doc__ = _zmq.Socket.get.__doc__

@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: # type: ignore
...

@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
Expand All @@ -292,19 +277,6 @@ def recv_multipart(
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)

@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...

def recv( # type: ignore
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]:
Expand Down Expand Up @@ -440,15 +412,6 @@ def cancel_poll(future):

return future

# overrides only necessary for updated types
def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
return super().recv_string(*args, **kwargs) # type: ignore

def send_string( # type: ignore
self, s: str, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[None]:
return super().send_string(s, flags=flags, encoding=encoding) # type: ignore

def _add_timeout(self, future, timeout):
"""Add a timeout for a send or recv Future"""

Expand Down
92 changes: 92 additions & 0 deletions zmq/_future.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""type annotations for async sockets"""

from __future__ import annotations

from asyncio import Future
from pickle import DEFAULT_PROTOCOL
from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload

import zmq as _zmq

class _AsyncPoller(_zmq.Poller):
_socket_class: type[_AsyncSocket]

def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore

T = TypeVar("T", bound="_AsyncSocket")

class _AsyncSocket(_zmq.Socket[Future]):
@classmethod
def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ...
def send( # type: ignore
self,
data: Any,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...
@overload
def recv(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]: ...
def send_multipart( # type: ignore
self,
msg_parts: Sequence,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: ...
@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

# serialization wrappers

def send_string( # type: ignore
self,
u: str,
flags: int = 0,
copy: bool = True,
*,
encoding: str = 'utf-8',
**kwargs,
) -> Awaitable[_zmq.Frame | None]: ...
def recv_string( # type: ignore
self, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[str]: ...
def send_pyobj( # type: ignore
self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore
def send_json( # type: ignore
self, obj: Any, flags: int = 0, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore
def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore
Loading