WIP Subscribe events - PubSub Extension #5218

Closed · wants to merge 2 commits
82 changes: 80 additions & 2 deletions distributed/client.py
@@ -69,7 +69,7 @@
from .protocol import to_serialize
from .protocol.pickle import dumps, loads
from .publish import Datasets
-from .pubsub import PubSubClientExtension
+from .pubsub import PubSubClientExtension, Sub
from .security import Security
from .sizeof import sizeof
from .threadpoolexecutor import rejoin
@@ -710,6 +710,8 @@ def __init__(
self._set_config = dask.config.set(
scheduler="dask.distributed", shuffle="tasks"
)
self.event_handlers = {}
self.event_subscriptions = {}

self._stream_handlers = {
"key-in-memory": self._handle_key_in_memory,
@@ -719,6 +721,7 @@ def __init__(
"task-erred": self._handle_task_erred,
"restart": self._handle_restart,
"error": self._handle_error,
# "event": self._handle_event,
}

self._state_handlers = {
@@ -1292,6 +1295,9 @@ async def _close(self, fast=False):
for pc in self._periodic_callbacks.values():
pc.stop()

for topic in set(self.event_subscriptions):
self.unsubscribe_topic(topic)

with log_errors():
_del_global_client(self)
self._scheduler_identity = {}
@@ -3562,7 +3568,17 @@ def log_event(self, topic, msg):
>>> from time import time
>>> client.log_event("current-time", time())
"""
-return self.sync(self.scheduler.log_event, topic=topic, msg=msg)
from .pubsub import Pub

if not hasattr(self, "event_publishers"):
self.event_publishers = {}
pub = self.event_publishers.get(topic)
if pub is None:
self.event_publishers[topic] = pub = Pub(
topic,
log_queue=True,
)
pub.put(msg)

def get_events(self, topic: str = None):
"""Retrieve structured topic logs
@@ -3575,6 +3591,68 @@ def get_events(self, topic: str = None):
"""
return self.sync(self.scheduler.events, topic=topic)

async def _handle_events(self, sub, handler):
while self.status == "running":
event = await sub.get()
ret = handler(event)
if inspect.isawaitable(ret):
await ret

def subscribe_topic(self, topic, handler):
"""Subscribe to a topic and execute a handler for every received event

Parameters
----------
topic: str
The topic name
handler: callable or coroutinefunction
A handler called for every received event. The handler must accept a
single argument `event` which is a tuple `(timestamp, msg)` where
timestamp refers to the clock on the scheduler.

Examples
--------

>>> import logging
>>> logger = logging.getLogger("myLogger") # Log config not shown
>>> client.subscribe_topic("topic-name", lambda event: logger.info(event))

See Also
--------
dask.distributed.Client.unsubscribe_topic
dask.distributed.Client.get_events
dask.distributed.Client.log_event
"""
if topic not in self.event_subscriptions:
sub = self.event_subscriptions[topic] = Sub(topic, client=self)
else:
sub = self.event_subscriptions[topic]

if topic in self.event_handlers:
fut = self.event_handlers[topic]
fut.cancel()

self.event_handlers[topic] = fut = asyncio.ensure_future(
self._handle_events(sub=sub, handler=handler)
)

def unsubscribe_topic(self, topic):
"""Unsubscribe from a topic and remove event handler

See Also
--------
dask.distributed.Client.subscribe_topic
dask.distributed.Client.get_events
dask.distributed.Client.log_event
"""
if topic in self.event_handlers:
fut = self.event_handlers.pop(topic)
fut.cancel()
sub = self.event_subscriptions.pop(topic)
sub.stop()
else:
raise ValueError(f"No event handler known for topic {topic}.")

Comment on lines +3626 to +3653 (Member Author): This is mostly convenience stuff to execute an event handler on the client side; the actual subscription is handled via the extension.

def retire_workers(self, workers=None, close_workers=True, **kwargs):
"""Retire certain workers on the scheduler

…
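Taken together, the client.py changes add a per-topic callback API on top of the existing pub/sub machinery. A minimal usage sketch of the new methods, assuming a running cluster; the topic name, payload, and handler are illustrative and not part of this diff:

import logging

from dask.distributed import Client

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("events")

client = Client()  # connects to (or starts) a cluster

def on_event(event):
    # Handlers receive a (timestamp, msg) tuple; the timestamp comes
    # from the scheduler's clock.
    timestamp, msg = event
    logger.info("event at %s: %s", timestamp, msg)

client.subscribe_topic("example-topic", on_event)
client.log_event("example-topic", {"payload": 123})  # handler fires on delivery
client.unsubscribe_topic("example-topic")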
12 changes: 12 additions & 0 deletions distributed/nanny.py
@@ -29,6 +29,7 @@
from .process import AsyncProcess
from .proctitle import enable_proctitle_on_children
from .protocol import pickle
from .pubsub import Pub
from .security import Security
from .utils import (
TimeoutError,
@@ -607,6 +608,17 @@ async def close(self, comm=None, timeout=5, report=None):
await comm.write("OK")
await super().close()

def log_event(self, topic, msg):
if not hasattr(self, "event_publishers"):
self.event_publishers = {}
pub = self.event_publishers.get(topic)
if pub is None:
self.event_publishers[topic] = pub = Pub(
topic,
log_queue=True,
)
pub.put(msg)


class WorkerProcess:
# The interval at which to check the msg queue for init
…
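Nanny.log_event duplicates the memoized-Pub lookup from Client.log_event above. If the pattern sticks, it could move into a small shared helper; a sketch under that assumption (the _event_publisher name is invented here, not part of the PR):

from distributed.pubsub import Pub

def _event_publisher(obj, topic):
    # Lazily create and cache one Pub per topic on the owning object.
    publishers = getattr(obj, "event_publishers", None)
    if publishers is None:
        publishers = obj.event_publishers = {}
    pub = publishers.get(topic)
    if pub is None:
        pub = publishers[topic] = Pub(topic, log_queue=True)
    return pub

def log_event(self, topic, msg):
    _event_publisher(self, topic).put(msg)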
59 changes: 43 additions & 16 deletions distributed/pubsub.py
@@ -8,7 +8,7 @@

from .core import CommClosedError
from .metrics import time
-from .protocol.serialize import to_serialize
from .protocol.serialize import Serialized, deserialize # , to_serialize
from .utils import TimeoutError, sync

logger = logging.getLogger(__name__)
@@ -36,13 +36,21 @@ def __init__(self, scheduler):

self.scheduler.extensions["pubsub"] = self

-def add_publisher(self, comm=None, name=None, worker=None):
def add_publisher(self, comm=None, name=None, worker=None, log_queue=False):
logger.debug("Add publisher: %s %s", name, worker)
self.publishers[name].add(worker)

if log_queue:
self.scheduler.events[name] # init defaultdict
assert name in self.scheduler.events

return {
"subscribers": {addr: {} for addr in self.subscribers[name]},
"publish-scheduler": name in self.client_subscribers
and len(self.client_subscribers[name]) > 0,
"publish-scheduler": (
name in self.client_subscribers
and len(self.client_subscribers[name]) > 0
)
or name in self.scheduler.events,
}

def add_subscriber(self, comm=None, name=None, worker=None, client=None):
@@ -75,15 +83,15 @@ def remove_publisher(self, comm=None, name=None, worker=None):
def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
if worker:
logger.debug("Remove worker subscriber: %s %s", name, worker)
-self.subscribers[name].remove(worker)
self.subscribers[name].discard(worker)
for pub in self.publishers[name]:
self.scheduler.worker_send(
pub,
{"op": "pubsub-remove-subscriber", "address": worker, "name": name},
)
elif client:
logger.debug("Remove client subscriber: %s %s", name, client)
-self.client_subscribers[name].remove(client)
self.client_subscribers[name].discard(client)
if not self.client_subscribers[name]:
del self.client_subscribers[name]
for pub in self.publishers[name]:
@@ -110,6 +118,14 @@ def handle_message(self, name=None, msg=None, worker=None, client=None):
except (KeyError, CommClosedError):
self.remove_subscriber(name=name, client=c)

if name in self.scheduler.events:
# FIXME: Am I allowed to do this? Feels evil
if isinstance(msg, Serialized):
msg = deserialize(msg.header, msg.frames)
event = (time(), msg)
self.scheduler.events[name].append(event)
self.scheduler.event_counts[name] += 1

if client:
for sub in self.subscribers[name]:
self.scheduler.worker_send(
@@ -143,7 +159,7 @@ def add_subscriber(self, name=None, address=None, **info):

def remove_subscriber(self, name=None, address=None):
for pub in self.publishers[name]:
-del pub.subscribers[address]
pub.subscribers.pop(address, None)

def publish_scheduler(self, name=None, publish=None):
self.publish_to_scheduler[name] = publish
@@ -247,6 +263,8 @@ class Pub:
client: Client (optional)
Client used for communication with the scheduler. Defaults to
the value of ``get_client()``. If given, ``worker`` must be ``None``.
log_queue: bool
If True, log the events in an event queue on the scheduler.

Examples
--------
@@ -283,7 +301,7 @@ class Pub:
Sub
"""

-def __init__(self, name, worker=None, client=None):
def __init__(self, name, worker=None, client=None, log_queue=False):
if worker is None and client is None:
from distributed import get_client, get_worker

@@ -304,6 +322,7 @@ def __init__(self, name, worker=None, client=None):
self.loop = self.client.loop

self.name = name
self.log_queue = log_queue
self._started = False
self._buffer = []

@@ -317,7 +336,7 @@ def __init__(self, name, worker=None, client=None):
async def _start(self):
if self.worker:
result = await self.scheduler.pubsub_add_publisher(
-name=self.name, worker=self.worker.address
name=self.name, worker=self.worker.address, log_queue=self.log_queue
)
pubsub = self.worker.extensions["pubsub"]
self.subscribers.update(result["subscribers"])
@@ -334,7 +353,8 @@ def _put(self, msg):
self._buffer.append(msg)
return

-data = {"op": "pubsub-msg", "name": self.name, "msg": to_serialize(msg)}
# FIXME If I use to_serialize here, this breaks msgs of type dict!
data = {"op": "pubsub-msg", "name": self.name, "msg": msg}
Comment on lines +356 to +357 (Member Author): This broke for simple dict-type messages by raising a KeyError on the remote side. Probably a bug, but how to handle serialized objects in general is one of my questions.


if self.worker:
for sub in self.subscribers:
@@ -388,12 +408,7 @@ def __init__(self, name, worker=None, client=None):
self.loop.add_callback(pubsub.subscribers[name].add, self)

msg = {"op": "pubsub-add-subscriber", "name": self.name}
-if self.worker:
-self.loop.add_callback(self.worker.batched_stream.send, msg)
-elif self.client:
-self.loop.add_callback(self.client.scheduler_comm.send, msg)
-else:
-raise Exception()
self._send_message(msg)

weakref.finalize(self, pubsub.trigger_cleanup)

@@ -461,6 +476,18 @@ async def _put(self, msg):
async with self.condition:
self.condition.notify()

def _send_message(self, msg):
if self.worker:
self.loop.add_callback(self.worker.batched_stream.send, msg)
elif self.client:
self.loop.add_callback(self.client.scheduler_comm.send, msg)
else:
raise Exception()

def stop(self):
msg = {"op": "pubsub-remove-subscriber", "name": self.name}
self._send_message(msg)

def __repr__(self):
return f"<Sub: {self.name}>"

…
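With log_queue=True, messages published on a topic are also appended to the scheduler's structured event log, so they remain retrievable after the fact; Sub additionally gains an explicit stop(). A rough end-to-end sketch of that flow (the topic name and payload are illustrative, and this assumes the PR's behavior as written):

from dask.distributed import Client
from distributed.pubsub import Pub, Sub

def produce():
    # Runs inside a task on a worker: publish and mirror into the event log.
    pub = Pub("alerts", log_queue=True)
    pub.put({"severity": "low", "detail": "example"})
    return True

client = Client()
client.submit(produce).result()

# Because log_queue=True, the message also lands in the structured log:
print(client.get_events("alerts"))  # [(timestamp, {...}), ...]

# Subscribers can now detach explicitly:
sub = Sub("alerts", client=client)
sub.stop()  # sends "pubsub-remove-subscriber" to the scheduler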
13 changes: 5 additions & 8 deletions distributed/scheduler.py
@@ -7399,14 +7399,11 @@ async def get_worker_logs(self, comm=None, n=None, workers=None, nanny=False):
return results

def log_event(self, name, msg):
-event = (time(), msg)
-if isinstance(name, list):
-for n in name:
-self.events[n].append(event)
-self.event_counts[n] += 1
-else:
-self.events[name].append(event)
-self.event_counts[name] += 1
if not isinstance(name, list):
name = [name]
for n in name:
self.events[n] # init defaultdict
self.extensions["pubsub"].handle_message(name=n, msg=msg)
Comment (Member Author): This is currently ugly, but the extension's handle_message now appends the event to the deque and increments the counter.


def get_events(self, comm=None, topic=None):
if topic is not None:
…
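After this change, Scheduler.log_event normalizes name to a list and hands each topic to the pubsub extension, whose handle_message now appends the (timestamp, msg) tuple and increments the counter. A test-style sketch of the expected behavior, assuming the usual gen_cluster test harness (the test name and topics are illustrative):

from distributed.utils_test import gen_cluster

@gen_cluster(client=True)
async def test_log_event_routes_through_pubsub(c, s, a, b):
    s.log_event("alpha", {"x": 1})
    s.log_event(["alpha", "beta"], {"x": 2})  # list form fans out per topic
    assert s.event_counts["alpha"] == 2
    assert s.event_counts["beta"] == 1
    _, msg = s.events["beta"][0]  # events are stored as (timestamp, msg)
    assert msg == {"x": 2}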