diff --git a/jupyter_client/session.py b/jupyter_client/session.py index c387cd06..f3e7afca 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -28,6 +28,7 @@ # We are using compare_digest to limit the surface of timing attacks import zmq.asyncio +from jupyter_core.utils import run_sync from tornado.ioloop import IOLoop from traitlets import ( Any, @@ -812,7 +813,13 @@ def send( if isinstance(stream, zmq.asyncio.Socket): assert stream is not None # type:ignore[unreachable] - stream = zmq.Socket.shadow(stream.underlying) + + async def _send_multipart(*args, **kwargs): + return await stream.send_multipart(*args, **kwargs) + + send_func = run_sync(_send_multipart) + elif stream is not None: + send_func = stream.send_multipart if isinstance(msg_or_type, (Message, dict)): # We got a Message or message dict, not a msg_type so don't @@ -856,11 +863,11 @@ def send( if stream and buffers and track and not copy: # only really track when we are doing zero-copy buffers - tracker = stream.send_multipart(to_send, copy=False, track=True) + tracker = send_func(to_send, copy=False, track=True) elif stream: # use dummy tracker, which will be done immediately tracker = DONE - stream.send_multipart(to_send, copy=copy) + send_func(to_send, copy=copy) else: tracker = DONE @@ -907,8 +914,15 @@ def send_raw( to_send.append(self.sign(msg_list[0:4])) to_send.extend(msg_list) if isinstance(stream, zmq.asyncio.Socket): - stream = zmq.Socket.shadow(stream.underlying) - stream.send_multipart(to_send, flags, copy=copy) + assert stream is not None + + async def _send_multipart(*args: t.Any, **kwargs: t.Any) -> None: + await stream.send_multipart(*args, **kwargs) + + send_func = run_sync(_send_multipart) + else: + send_func = stream.send_multipart + send_func(to_send, flags, copy=copy) def recv( self, @@ -932,11 +946,18 @@ def recv( """ if isinstance(socket, ZMQStream): # type:ignore[unreachable] socket = socket.socket # type:ignore[unreachable] + if isinstance(socket, zmq.asyncio.Socket): - socket = zmq.Socket.shadow(socket.underlying) + + async def _recv_multipart(*args: t.Any, **kwargs: t.Any) -> t.Any: + return await socket.recv_multipart(*args, **kwargs) + + recv_func = run_sync(_recv_multipart) + else: + recv_func = socket.recv_multipart try: - msg_list = socket.recv_multipart(mode, copy=copy) + msg_list = recv_func(mode, copy=copy) except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case