From 0c057ba27f4f8f055eef0f3c78fc0e85fdc94ae0 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 19 Jul 2021 09:27:54 -0400 Subject: [PATCH 01/10] Start to extract SchedulerState --- distributed/scheduler.py | 240 +++++++++++++-------------------------- 1 file changed, 81 insertions(+), 159 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 9c25d2c85b..1d8bea6c68 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -24,6 +24,7 @@ import sortedcontainers from tlz import ( compose, + concat, first, groupby, merge, @@ -3281,7 +3282,7 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: return (start_time, ws._nbytes) -class Scheduler(SchedulerState, ServerNode): +class Scheduler(ServerNode): """Dynamic distributed task scheduler The scheduler tracks the current state of workers, data, and computations. @@ -3439,18 +3440,19 @@ def __init__( http_server_modules = dask.config.get("distributed.scheduler.http.routes") show_dashboard = dashboard or (dashboard is None and dashboard_address) + missing_bokeh = False # install vanilla route if show_dashboard but bokeh is not installed if show_dashboard: try: import distributed.dashboard.scheduler except ImportError: - show_dashboard = False + missing_bokeh = True http_server_modules.append("distributed.http.scheduler.missing_bokeh") routes = get_handlers( server=self, modules=http_server_modules, prefix=http_prefix ) self.start_http_server(routes, dashboard_address, default_port=8787) - if show_dashboard: + if show_dashboard and not missing_bokeh: distributed.dashboard.scheduler.connect( self.http_application, self.http_server, self, prefix=http_prefix ) @@ -3642,8 +3644,7 @@ def __init__( } connection_limit = get_fileno_limit() / 2 - - super().__init__( + self.state: SchedulerState = SchedulerState( aliases=aliases, handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), @@ -3686,7 +3687,7 @@ def __init__( ################## def __repr__(self): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state return '' % ( self.address, len(parent._workers_dv), @@ -3695,7 +3696,7 @@ def __repr__(self): ) def _repr_html_(self): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state text = ( f"Scheduler: {html.escape(self.address)} " f'workers: {len(parent._workers_dv)} ' @@ -3706,7 +3707,7 @@ def _repr_html_(self): def identity(self, comm=None): """Basic information about ourselves and our cluster""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state d = { "type": type(self).__name__, "id": str(self.id), @@ -3734,7 +3735,7 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): Whether or not to include a full address with protocol (True) or just a (host, port) pair """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState = parent._workers_dv[worker] port = ws._services.get(service_name) if port is None: @@ -3945,9 +3946,7 @@ def heartbeat_worker( ws._last_seen = local_now if executing is not None: ws._executing = { - parent._tasks[key]: duration - for key, duration in executing.items() - if key in parent._tasks + parent._tasks[key]: duration for key, duration in executing.items() } ws._metrics = metrics @@ -5391,7 +5390,7 @@ async def scatter( return keys async def gather(self, comm=None, keys=None, serializers=None): - """Collect data from workers to the scheduler""" + 
"""Collect data in from workers""" parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState keys = list(keys) @@ -5467,8 +5466,8 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - async def restart(self, client=None, timeout=30): - """Restart all workers. Reset local state.""" + async def restart(self, client=None, timeout=3): + """Restart all workers. Reset local state.""" parent: SchedulerState = cast(SchedulerState, self) with log_errors(): @@ -5597,108 +5596,31 @@ async def proxy(self, comm=None, msg=None, worker=None, serializers=None): ) return d[worker] - async def _gather_on_worker( - self, worker_address: str, who_has: "dict[Hashable, list[str]]" - ) -> set: - """Peer-to-peer copy of keys from multiple workers to a single worker - - Parameters - ---------- - worker_address: str - Recipient worker address to copy keys to - who_has: dict[Hashable, list[str]] - {key: [sender address, sender address, ...], key: ...} - - Returns - ------- - returns: - set of keys that failed to be copied - """ - try: - result = await retry_operation( - self.rpc(addr=worker_address).gather, who_has=who_has - ) - except OSError as e: - # This can happen e.g. if the worker is going through controlled shutdown; - # it doesn't necessarily mean that it went unexpectedly missing - logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" - ) - return set(who_has) - - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker_address) - - if ws is None: - logger.warning(f"Worker {worker_address} lost during replication") - return set(who_has) - elif result["status"] == "OK": - keys_failed = set() - keys_ok = who_has.keys() - elif result["status"] == "partial-fail": - keys_failed = set(result["keys"]) - keys_ok = who_has.keys() - keys_failed - logger.warning( - f"Worker {worker_address} failed to acquire keys: {result['keys']}" - ) - else: # pragma: nocover - raise ValueError(f"Unexpected message from {worker_address}: {result}") - - for key in keys_ok: - ts: TaskState = parent._tasks.get(key) - if ts is None or ts._state != "memory": - logger.warning(f"Key lost during replication: {key}") - continue - if ts not in ws._has_what: - ws._nbytes += ts.get_nbytes() - ws._has_what[ts] = None - ts._who_has.add(ws) - - return keys_failed - - async def _delete_worker_data(self, worker_address: str, keys: "list[str]") -> None: + async def _delete_worker_data(self, worker_address, keys): """Delete data from a worker and update the corresponding worker/task states Parameters ---------- worker_address: str Worker address to delete keys from - keys: list[str] + keys: List[str] List of keys to delete on the specified worker """ parent: SchedulerState = cast(SchedulerState, self) - try: - await retry_operation( - self.rpc(addr=worker_address).free_keys, - keys=list(keys), - reason="rebalance/replicate", - ) - except OSError as e: - # This can happen e.g. 
if the worker is going through controlled shutdown; - # it doesn't necessarily mean that it went unexpectedly missing - logger.warning( - f"Communication with worker {worker_address} failed during " - f"replication: {e.__class__.__name__}: {e}" - ) - return - - ws: WorkerState = parent._workers_dv.get(worker_address) - if ws is None: - return - - for key in keys: - ts: TaskState = parent._tasks.get(key) - if ts is not None and ts in ws._has_what: - assert ts._state == "memory" - del ws._has_what[ts] - ts._who_has.remove(ws) - ws._nbytes -= ts.get_nbytes() - if not ts._who_has: - # Last copy deleted - self.transitions({key: "released"}) + await retry_operation( + self.rpc(addr=worker_address).free_keys, + keys=list(keys), + reason="rebalance/replicate", + ) + ws: WorkerState = parent._workers_dv[worker_address] + ts: TaskState + tasks: set = {parent._tasks[key] for key in keys} + for ts in tasks: + del ws._has_what[ts] + ts._who_has.remove(ws) + ws._nbytes -= ts.get_nbytes() self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) async def rebalance( @@ -5793,18 +5715,14 @@ async def rebalance( if k not in parent._tasks or not parent._tasks[k].who_has ] if missing_data: - return {"status": "partial-fail", "keys": missing_data} + return {"status": "missing-data", "keys": missing_data} msgs = self._rebalance_find_msgs(keys, workers) if not msgs: return {"status": "OK"} async with self._lock: - result = await self._rebalance_move_data(msgs) - if result["status"] == "partial-fail" and keys is None: - # Only return failed keys if the client explicitly asked for them - result = {"status": "OK"} - return result + return await self._rebalance_move_data(msgs) def _rebalance_find_msgs( self: SchedulerState, @@ -5961,7 +5879,7 @@ def _rebalance_find_msgs( # move on to the next task of the same sender. continue - # Schedule task for transfer from sender to recipient + # Schedule task for transfer from sender to receiver msgs.append((snd_ws, rec_ws, ts)) # *_bytes_max/min are all negative for heap sorting @@ -5982,7 +5900,7 @@ def _rebalance_find_msgs( else: heapq.heappop(senders) - # If recipient still has bytes to gain, push it back into the recipients + # If receiver still has bytes to gain, push it back into the receivers # heap; it may or may not come back on top again. if rec_bytes_min < 0: # See definition of recipients above @@ -6007,46 +5925,29 @@ async def _rebalance_move_data( self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" ) -> dict: """Perform the actual transfer of data across the network in rebalance(). - Takes in input the output of _rebalance_find_msgs(), that is a list of tuples: - - - sender worker - - recipient worker - - task to be transferred + Takes in input the output of _rebalance_find_msgs(). FIXME this method is not robust when the cluster is not idle. 
""" + ts: TaskState snd_ws: WorkerState rec_ws: WorkerState - ts: TaskState to_recipients = defaultdict(lambda: defaultdict(list)) - for snd_ws, rec_ws, ts in msgs: - to_recipients[rec_ws.address][ts._key].append(snd_ws.address) - failed_keys_by_recipient = dict( - zip( - to_recipients, - await asyncio.gather( - *( - # Note: this never raises exceptions - self._gather_on_worker(w, who_has) - for w, who_has in to_recipients.items() - ) - ), - ) - ) - to_senders = defaultdict(list) - for snd_ws, rec_ws, ts in msgs: - if ts._key not in failed_keys_by_recipient[rec_ws.address]: - to_senders[snd_ws.address].append(ts._key) + for sender, recipient, ts in msgs: + to_recipients[recipient.address][ts._key].append(sender.address) + to_senders[sender.address].append(ts._key) - # Note: this never raises exceptions - await asyncio.gather( - *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + result = await asyncio.gather( + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) ) - for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) + self.log_event( "all", { @@ -6057,11 +5958,31 @@ async def _rebalance_move_data( }, ) - missing_keys = {k for r in failed_keys_by_recipient.values() for k in r} - if missing_keys: - return {"status": "partial-fail", "keys": list(missing_keys)} - else: - return {"status": "OK"} + if any(r["status"] != "OK" for r in result): + return { + "status": "missing-data", + "keys": list( + concat( + r["keys"].keys() + for r in result + if r["status"] == "missing-data" + ) + ), + } + + for snd_ws, rec_ws, ts in msgs: + assert ts._state == "memory" + ts._who_has.add(rec_ws) + rec_ws._has_what[ts] = None + rec_ws.nbytes += ts.get_nbytes() + self.log.append( + ("rebalance", ts._key, time(), snd_ws.address, rec_ws.address) + ) + + await asyncio.gather( + *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + ) + return {"status": "OK"} async def replicate( self, @@ -6112,7 +6033,7 @@ async def replicate( tasks = {parent._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: - return {"status": "partial-fail", "keys": missing_data} + return {"status": "missing-data", "keys": missing_data} # Delete extraneous data if delete: @@ -6125,7 +6046,6 @@ async def replicate( ): del_worker_tasks[ws].add(ts) - # Note: this never raises exceptions await asyncio.gather( *[ self._delete_worker_data(ws._address, [t.key for t in tasks]) @@ -6155,15 +6075,19 @@ async def replicate( wws._address for wws in ts._who_has ] - await asyncio.gather( + results = await asyncio.gather( *( - # Note: this never raises exceptions - self._gather_on_worker(w, who_has) + retry_operation(self.rpc(addr=w).gather, who_has=who_has) for w, who_has in gathers.items() ) ) - for r, v in gathers.items(): - self.log_event(r, {"action": "replicate-add", "who_has": v}) + for w, v in zip(gathers, results): + if v["status"] == "OK": + self.add_keys(worker=w, keys=list(gathers[w])) + else: + logger.warning("Communication failed during replication: %s", v) + + self.log_event(w, {"action": "replicate-add", "keys": gathers[w]}) self.log_event( "all", @@ -7027,9 +6951,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} - async def performance_report( - self, comm=None, start=None, last_count=None, code="", mode=None - ): + async def performance_report(self, comm=None, start=None, last_count=None, code=""): parent: SchedulerState = cast(SchedulerState, self) stop 
= time() # Profiles @@ -7169,7 +7091,7 @@ def profile_to_figure(state): from bokeh.plotting import output_file, save with tmpfile(extension=".html") as fn: - output_file(filename=fn, title="Dask Performance Report", mode=mode) + output_file(filename=fn, title="Dask Performance Report") template_directory = os.path.join( os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" ) @@ -7729,7 +7651,7 @@ def validate_task_state(ts: TaskState): assert dts._state != "forgotten" assert (ts._processing_on is not None) == (ts._state == "processing") - assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state) + assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has) if ts._state == "processing": assert all([dts._who_has for dts in ts._dependencies]), ( From 3d89be57dde7d9da217c1f63d3047fe63a6d0785 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 20 Jul 2021 09:29:08 -0400 Subject: [PATCH 02/10] Move scheduler properties --- distributed/scheduler.py | 297 ++++++++++++++++++++------------------- 1 file changed, 150 insertions(+), 147 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1d8bea6c68..8dc1dd3a75 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1719,6 +1719,7 @@ def _task_key_or_none(task: TaskState): return task._key if task is not None else None +@final @cclass class SchedulerState: """Underlying task state of dynamic scheduler @@ -1894,149 +1895,6 @@ def __init__( ) self._transition_counter = 0 - super().__init__(**kwargs) - - @property - def aliases(self): - return self._aliases - - @property - def bandwidth(self): - return self._bandwidth - - @property - def clients(self): - return self._clients - - @property - def extensions(self): - return self._extensions - - @property - def host_info(self): - return self._host_info - - @property - def idle(self): - return self._idle - - @property - def n_tasks(self): - return self._n_tasks - - @property - def resources(self): - return self._resources - - @property - def saturated(self): - return self._saturated - - @property - def tasks(self): - return self._tasks - - @property - def task_groups(self): - return self._task_groups - - @property - def task_prefixes(self): - return self._task_prefixes - - @property - def task_metadata(self): - return self._task_metadata - - @property - def total_nthreads(self): - return self._total_nthreads - - @property - def total_occupancy(self): - return self._total_occupancy - - @total_occupancy.setter - def total_occupancy(self, v: double): - self._total_occupancy = v - - @property - def transition_counter(self): - return self._transition_counter - - @property - def unknown_durations(self): - return self._unknown_durations - - @property - def unrunnable(self): - return self._unrunnable - - @property - def validate(self): - return self._validate - - @validate.setter - def validate(self, v: bint): - self._validate = v - - @property - def workers(self): - return self._workers - - @property - def memory(self) -> MemoryState: - return MemoryState.sum(*(w.memory for w in self.workers.values())) - - @property - def __pdict__(self): - return { - "bandwidth": self._bandwidth, - "resources": self._resources, - "saturated": self._saturated, - "unrunnable": self._unrunnable, - "n_tasks": self._n_tasks, - "unknown_durations": self._unknown_durations, - "validate": self._validate, - "tasks": self._tasks, - "task_groups": self._task_groups, - "task_prefixes": self._task_prefixes, - "total_nthreads": 
self._total_nthreads, - "total_occupancy": self._total_occupancy, - "extensions": self._extensions, - "clients": self._clients, - "workers": self._workers, - "idle": self._idle, - "host_info": self._host_info, - } - - @ccall - @exceptval(check=False) - def new_task(self, key: str, spec: object, state: str) -> TaskState: - """Create a new task, and associated states""" - ts: TaskState = TaskState(key, spec) - ts._state = state - - tp: TaskPrefix - prefix_key = key_split(key) - tp = self._task_prefixes.get(prefix_key) - if tp is None: - self._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) - ts._prefix = tp - - tg: TaskGroup - group_key = ts._group_key - tg = self._task_groups.get(group_key) - if tg is None: - self._task_groups[group_key] = tg = TaskGroup(group_key) - tg._prefix = tp - tp._groups.append(tg) - tg.add(ts) - - self._tasks[key] = ts - - return ts - ##################### # State Transitions # ##################### @@ -3545,9 +3403,9 @@ def __init__( resources = dict() aliases = dict() - self._task_state_collections = [unrunnable] + self._task_state_collections: list = [unrunnable] - self._worker_collections = [ + self._worker_collections: list = [ workers, host_info, resources, @@ -3569,7 +3427,7 @@ def __init__( self.event_counts = defaultdict(int) self.worker_plugins = dict() - worker_handlers = { + worker_handlers: dict = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, "release": self.handle_release_data, @@ -3582,7 +3440,7 @@ def __init__( "log-event": self.log_worker_event, } - client_handlers = { + client_handlers: dict = { "update-graph": self.update_graph, "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, @@ -3682,6 +3540,151 @@ def __init__( self.rpc.allow_offload = False self.status = Status.undefined + #################### + # state properties # + #################### + + @property + def aliases(self): + return self.state._aliases + + @property + def bandwidth(self): + return self.state._bandwidth + + @property + def clients(self): + return self.state._clients + + @property + def extensions(self): + return self.state._extensions + + @property + def host_info(self): + return self.state._host_info + + @property + def idle(self): + return self.state._idle + + @property + def n_tasks(self): + return self.state._n_tasks + + @property + def resources(self): + return self.state._resources + + @property + def saturated(self): + return self.state._saturated + + @property + def tasks(self): + return self.state._tasks + + @property + def task_groups(self): + return self.state._task_groups + + @property + def task_prefixes(self): + return self.state._task_prefixes + + @property + def task_metadata(self): + return self.state._task_metadata + + @property + def total_nthreads(self): + return self.state._total_nthreads + + @property + def total_occupancy(self): + return self.state._total_occupancy + + @total_occupancy.setter + def total_occupancy(self, v: double): + self.state._total_occupancy = v + + @property + def transition_counter(self): + return self.state._transition_counter + + @property + def unknown_durations(self): + return self.state._unknown_durations + + @property + def unrunnable(self): + return self.state._unrunnable + + @property + def validate(self): + return self.state._validate + + @validate.setter + def validate(self, v: bint): + self.state._validate = v + + @property + def workers(self): + return self.state._workers + + @property + def memory(self) -> MemoryState: + return 
MemoryState.sum(*(w.memory for w in self.state.workers.values())) + + @property + def __pdict__(self): + return { + "bandwidth": self.state._bandwidth, + "resources": self.state._resources, + "saturated": self.state._saturated, + "unrunnable": self.state._unrunnable, + "n_tasks": self.state._n_tasks, + "unknown_durations": self.state._unknown_durations, + "validate": self.state._validate, + "tasks": self.state._tasks, + "task_groups": self.state._task_groups, + "task_prefixes": self.state._task_prefixes, + "total_nthreads": self.state._total_nthreads, + "total_occupancy": self.state._total_occupancy, + "extensions": self.state._extensions, + "clients": self.state._clients, + "workers": self.state._workers, + "idle": self.state._idle, + "host_info": self.state._host_info, + } + + @ccall + @exceptval(check=False) + def new_task(self, key: str, spec: object, state: str) -> TaskState: + """Create a new task, and associated states""" + ts: TaskState = TaskState(key, spec) + ts._state = state + + tp: TaskPrefix + prefix_key = key_split(key) + tp = self.state._task_prefixes.get(prefix_key) + if tp is None: + self.state._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) + ts._prefix = tp + + tg: TaskGroup + group_key = ts._group_key + tg = self.state._task_groups.get(group_key) + if tg is None: + self.state._task_groups[group_key] = tg = TaskGroup(group_key) + tg._prefix = tp + tp._groups.append(tg) + tg.add(ts) + + self.state._tasks[key] = ts + + return ts + ################## # Administration # ################## From a4cdb66f22a84d1f5daa82920f530ab581a0fead Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Thu, 5 Aug 2021 11:56:28 -0400 Subject: [PATCH 03/10] First working SchedulerState (client.submit().result() worked) --- distributed/scheduler.py | 302 ++++++++++++++++++--------------------- 1 file changed, 141 insertions(+), 161 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 6f23d7c557..bf390b1a19 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1888,6 +1888,7 @@ class SchedulerState: _idle: object _idle_dv: dict _n_tasks: Py_ssize_t + _plugins: list _resources: dict _saturated: set _tasks: dict @@ -1896,6 +1897,7 @@ class SchedulerState: _task_metadata: dict _total_nthreads: Py_ssize_t _total_occupancy: double + _transition_log: deque _transitions_table: dict _unknown_durations: dict _unrunnable: set @@ -1961,6 +1963,9 @@ def __init__( self._task_metadata = dict() self._total_nthreads = 0 self._total_occupancy = 0 + self._transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) self._transitions_table = { ("released", "waiting"): self.transition_released_waiting, ("waiting", "released"): self.transition_waiting_released, @@ -2011,8 +2016,6 @@ def __init__( ) self._transition_counter = 0 - super().__init__(**kwargs) - @property def aliases(self): return self._aliases @@ -2045,6 +2048,14 @@ def idle(self): def n_tasks(self): return self._n_tasks + @property + def plugins(self): + return self._plugins + + @plugins.setter + def plugins(self, val): + self._plugins = val + @property def resources(self): return self._resources @@ -2081,6 +2092,10 @@ def total_occupancy(self): def total_occupancy(self, v: double): self._total_occupancy = v + @property + def transition_log(self): + return self._transition_log + @property def transition_counter(self): return self._transition_counter @@ -2166,6 +2181,7 @@ def new_task( # State Transitions # ##################### + # cannot be @ccall with 
args/kwargs def _transition(self, key, finish: str, *args, **kwargs): """Transition a key from its current state to the finish state @@ -2182,7 +2198,6 @@ def _transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions : transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState start: str start_finish: tuple @@ -2199,7 +2214,7 @@ def _transition(self, key, finish: str, *args, **kwargs): worker_msgs = {} client_msgs = {} - ts = parent._tasks.get(key) + ts = self._tasks.get(key) if ts is None: return recommendations, client_msgs, worker_msgs start = ts._state @@ -2265,8 +2280,8 @@ def _transition(self, key, finish: str, *args, **kwargs): raise RuntimeError("Impossible transition from %r to %r" % start_finish) finish2 = ts._state - self.transition_log.append((key, start, finish2, recommendations, time())) - if parent._validate: + self._transition_log.append((key, start, finish2, recommendations, time())) + if self._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", key, @@ -2280,17 +2295,17 @@ def _transition(self, key, finish: str, *args, **kwargs): if ts._state == "forgotten": ts._dependents = dependents ts._dependencies = dependencies - parent._tasks[ts._key] = ts + self._tasks[ts._key] = ts for plugin in list(self.plugins): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del parent._tasks[ts._key] + del self._tasks[ts._key] tg: TaskGroup = ts._group - if ts._state == "forgotten" and tg._name in parent._task_groups: + if ts._state == "forgotten" and tg._name in self._task_groups: # Remove TaskGroup if all tasks are in the forgotten state all_forgotten: bint = True for s in ALL_TASK_STATES: @@ -2299,7 +2314,7 @@ def _transition(self, key, finish: str, *args, **kwargs): break if all_forgotten: ts._prefix._groups.remove(tg) - del parent._task_groups[tg._name] + del self._task_groups[tg._name] return recommendations, client_msgs, worker_msgs except Exception: @@ -2316,7 +2331,6 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di This includes feedback from previous transitions and continues until we reach a steady state """ - parent: SchedulerState = cast(SchedulerState, self) keys: set = set() recommendations = recommendations.copy() msgs: list @@ -2346,7 +2360,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di else: worker_msgs[w] = new_msgs - if parent._validate: + if self._validate: for key in keys: self.validate_key(key) @@ -3686,10 +3700,6 @@ def __init__( aliases, ] - self.plugins = list(plugins) - self.transition_log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) self.log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) @@ -3783,12 +3793,6 @@ def __init__( connection_limit = get_fileno_limit() / 2 self.state: SchedulerState = SchedulerState( aliases=aliases, - handlers=self.handlers, - stream_handlers=merge(worker_handlers, client_handlers), - io_loop=self.loop, - connection_limit=connection_limit, - deserialize=False, - connection_args=self.connection_args, clients=clients, workers=workers, host_info=host_info, @@ -3796,9 +3800,8 @@ def __init__( tasks=tasks, unrunnable=unrunnable, validate=validate, - **kwargs, ) - + self.plugins = list(plugins) if self.worker_ttl: pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl) 
self.periodic_callbacks["worker-ttl"] = pc @@ -3811,6 +3814,18 @@ def __init__( extensions = list(DEFAULT_EXTENSIONS) if dask.config.get("distributed.scheduler.work-stealing"): extensions.append(WorkStealing) + + # Set up ServerNode stuff + super().__init__( + handlers=self.handlers, + stream_handlers=merge(worker_handlers, client_handlers), + io_loop=self.loop, + connection_limit=connection_limit, + deserialize=False, + connection_args=self.connection_args, + **kwargs, + ) + for ext in extensions: ext(self) @@ -3825,91 +3840,99 @@ def __init__( @property def aliases(self): - return self.state._aliases + return self.state.aliases @property def bandwidth(self): - return self.state._bandwidth + return self.state.bandwidth @property def clients(self): - return self.state._clients + return self.state.clients @property def extensions(self): - return self.state._extensions + return self.state.extensions @property def host_info(self): - return self.state._host_info + return self.state.host_info @property def idle(self): - return self.state._idle + return self.state.idle @property def n_tasks(self): - return self.state._n_tasks + return self.state.n_tasks + + @property + def plugins(self): + return self.state.plugins + + @plugins.setter + def plugins(self, val): + self.state.plugins = val @property def resources(self): - return self.state._resources + return self.state.resources @property def saturated(self): - return self.state._saturated + return self.state.saturated @property def tasks(self): - return self.state._tasks + return self.state.tasks @property def task_groups(self): - return self.state._task_groups + return self.state.task_groups @property def task_prefixes(self): - return self.state._task_prefixes + return self.state.task_prefixes @property def task_metadata(self): - return self.state._task_metadata + return self.state.task_metadata @property def total_nthreads(self): - return self.state._total_nthreads + return self.state.total_nthreads @property def total_occupancy(self): - return self.state._total_occupancy + return self.state.total_occupancy @total_occupancy.setter def total_occupancy(self, v: double): - self.state._total_occupancy = v + self.state.total_occupancy = v @property def transition_counter(self): - return self.state._transition_counter + return self.state.transition_counter @property def unknown_durations(self): - return self.state._unknown_durations + return self.state.unknown_durations @property def unrunnable(self): - return self.state._unrunnable + return self.state.unrunnable @property def validate(self): - return self.state._validate + return self.state.validate @validate.setter def validate(self, v: bint): - self.state._validate = v + self.state.validate = v @property def workers(self): - return self.state._workers + return self.state.workers @property def memory(self) -> MemoryState: @@ -3917,52 +3940,7 @@ def memory(self) -> MemoryState: @property def __pdict__(self): - return { - "bandwidth": self.state._bandwidth, - "resources": self.state._resources, - "saturated": self.state._saturated, - "unrunnable": self.state._unrunnable, - "n_tasks": self.state._n_tasks, - "unknown_durations": self.state._unknown_durations, - "validate": self.state._validate, - "tasks": self.state._tasks, - "task_groups": self.state._task_groups, - "task_prefixes": self.state._task_prefixes, - "total_nthreads": self.state._total_nthreads, - "total_occupancy": self.state._total_occupancy, - "extensions": self.state._extensions, - "clients": self.state._clients, - "workers": 
self.state._workers, - "idle": self.state._idle, - "host_info": self.state._host_info, - } - - @ccall - @exceptval(check=False) - def new_task(self, key: str, spec: object, state: str) -> TaskState: - """Create a new task, and associated states""" - ts: TaskState = TaskState(key, spec) - ts._state = state - - tp: TaskPrefix - prefix_key = key_split(key) - tp = self.state._task_prefixes.get(prefix_key) - if tp is None: - self.state._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) - ts._prefix = tp - - tg: TaskGroup - group_key = ts._group_key - tg = self.state._task_groups.get(group_key) - if tg is None: - self.state._task_groups[group_key] = tg = TaskGroup(group_key) - tg._prefix = tp - tp._groups.append(tg) - tg.add(ts) - - self.state._tasks[key] = ts - - return ts + return self.state.__pdict__ ################## # Administration # @@ -4099,7 +4077,7 @@ async def close(self, comm=None, fast=False, close_workers=False): -------- Scheduler.cleanup """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if self.status in (Status.closing, Status.closed, Status.closing_gracefully): await self.finished() return @@ -4190,7 +4168,7 @@ def heartbeat_worker( metrics: dict, executing: dict = None, ): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state address = self.coerce_address(address, resolve_address) address = normalize_address(address) ws: WorkerState = parent._workers_dv.get(address) @@ -4302,7 +4280,7 @@ async def add_worker( extra=None, ): """Add a new worker to the cluster""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) @@ -4363,7 +4341,7 @@ async def add_worker( ) # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot exist before this. - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) # for key in keys: # TODO # self.mark_key_in_memory(key, [address]) @@ -4535,7 +4513,7 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -4806,7 +4784,7 @@ def update_graph( def stimulus_task_finished(self, key=None, worker=None, **kwargs): """Mark that a task has finished execution on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state logger.debug("Stimulus task finished %s, %s", key, worker) recommendations: dict = {} @@ -4850,7 +4828,7 @@ def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs ): """Mark that a task has erred on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state logger.debug("Stimulus task erred %s, %s", key, worker) recommendations: dict = {} @@ -4885,7 +4863,7 @@ def stimulus_missing_data( self, cause=None, key=None, worker=None, ensure=True, **kwargs ): """Mark that certain keys have gone missing. 
Recover.""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state with log_errors(): logger.debug("Stimulus missing data %s, %s", key, worker) @@ -4919,7 +4897,7 @@ def stimulus_missing_data( return recommendations, client_msgs, worker_msgs def stimulus_retry(self, comm=None, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -4956,7 +4934,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): appears to be unresponsive. This may send its tasks back to a released state. """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state with log_errors(): if self.status == Status.closed: return @@ -5083,7 +5061,7 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks.get(key) dts: TaskState try: @@ -5106,7 +5084,7 @@ def cancel_key(self, key, client, retries=5, force=False): self.client_releases_keys(keys=[key], client=cs._client_key) def client_desires_keys(self, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state cs: ClientState = parent._clients.get(client) if cs is None: # For publish, queues etc. @@ -5126,7 +5104,7 @@ def client_desires_keys(self, keys=None, client=None): def client_releases_keys(self, keys=None, client=None): """Remove keys from client desired list""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if not isinstance(keys, list): keys = list(keys) cs: ClientState = parent._clients[client] @@ -5137,7 +5115,7 @@ def client_releases_keys(self, keys=None, client=None): def client_heartbeat(self, client=None): """Handle heartbeats from Client""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state cs: ClientState = parent._clients[client] cs._last_seen = time() @@ -5146,7 +5124,7 @@ def client_heartbeat(self, client=None): ################### def validate_released(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] dts: TaskState assert ts._state == "released" @@ -5158,7 +5136,7 @@ def validate_released(self, key): assert ts not in parent._unrunnable def validate_waiting(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] dts: TaskState assert ts._waiting_on @@ -5171,7 +5149,7 @@ def validate_waiting(self, key): assert ts in dts._waiters # XXX even if dts._who_has? 
def validate_processing(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] dts: TaskState assert not ts._waiting_on @@ -5184,7 +5162,7 @@ def validate_processing(self, key): assert ts in dts._waiters def validate_memory(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] dts: TaskState assert ts._who_has @@ -5196,7 +5174,7 @@ def validate_memory(self, key): assert ts not in dts._waiting_on def validate_no_worker(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] dts: TaskState assert ts in parent._unrunnable @@ -5208,13 +5186,13 @@ def validate_no_worker(self, key): assert dts._who_has def validate_erred(self, key): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks[key] assert ts._exception_blame assert not ts._who_has def validate_key(self, key, ts: TaskState = None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state try: if ts is None: ts = parent._tasks.get(key) @@ -5239,7 +5217,7 @@ def validate_key(self, key, ts: TaskState = None): raise def validate_state(self, allow_overlap=False): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state validate_state(parent._tasks, parent._workers, parent._clients) if not (set(parent._workers_dv) == set(self.stream_comms)): @@ -5296,7 +5274,7 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): If the message contains a key then we only send the message to those comms that care about the key. """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if ts is None: msg_key = msg.get("key") if msg_key is not None: @@ -5336,7 +5314,7 @@ async def add_client(self, comm, client=None, versions=None): We listen to all future messages from this Comm. 
""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) @@ -5382,7 +5360,7 @@ async def add_client(self, comm, client=None, versions=None): def remove_client(self, client=None): """Remove client from network""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) @@ -5416,7 +5394,7 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state try: msg: dict = _task_to_msg(parent, ts, duration) self.worker_send(worker, msg) @@ -5432,7 +5410,7 @@ def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) def handle_task_finished(self, key=None, worker=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if worker not in parent._workers_dv: return validate_key(key) @@ -5448,7 +5426,7 @@ def handle_task_finished(self, key=None, worker=None, **msg): self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state recommendations: dict client_msgs: dict worker_msgs: dict @@ -5459,7 +5437,7 @@ def handle_task_erred(self, key=None, **msg): self.send_all(client_msgs, worker_msgs) def handle_release_data(self, key=None, worker=None, client=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState = parent._tasks.get(key) if ts is None: return @@ -5478,7 +5456,7 @@ def handle_release_data(self, key=None, worker=None, client=None, **msg): self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log.append(("missing", key, errant_worker)) @@ -5497,7 +5475,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.transitions({key: "forgotten"}) def release_worker_data(self, comm=None, keys=None, worker=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState = parent._workers_dv[worker] tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks} removed_tasks: set = tasks.intersection(ws._has_what) @@ -5520,7 +5498,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. 
""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if key not in parent._tasks: logger.debug("Skipping long_running since key %s was already released", key) return @@ -5549,7 +5527,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): ws._occupancy -= occ parent._total_occupancy -= occ ws._processing[ts] = 0 - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) async def handle_worker(self, comm=None, worker=None): """ @@ -5674,7 +5652,7 @@ async def scatter( -------- Scheduler.broadcast: """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state start = time() while not parent._workers_dv: await asyncio.sleep(0.2) @@ -5710,7 +5688,7 @@ async def scatter( async def gather(self, comm=None, keys=None, serializers=None): """Collect data in from workers""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState keys = list(keys) who_has = {} @@ -5786,7 +5764,7 @@ def clear_task_state(self): async def restart(self, client=None, timeout=3): """Restart all workers. Reset local state.""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state with log_errors(): n_workers = len(parent._workers_dv) @@ -5874,7 +5852,7 @@ async def broadcast( serializers=None, ): """Broadcast message to workers, return all results""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if workers is None or workers is True: if hosts is None: workers = list(parent._workers_dv) @@ -5924,7 +5902,7 @@ async def _delete_worker_data(self, worker_address, keys): keys: List[str] List of keys to delete on the specified worker """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state await retry_operation( self.rpc(addr=worker_address).free_keys, @@ -6333,7 +6311,7 @@ async def replicate( -------- Scheduler.rebalance """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState wws: WorkerState ts: TaskState @@ -6484,7 +6462,7 @@ def workers_to_close( -------- Scheduler.retire_workers """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if target is not None and n is None: n = len(parent._workers_dv) - target if n is not None: @@ -6592,7 +6570,7 @@ async def retire_workers( -------- Scheduler.workers_to_close """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState ts: TaskState with log_errors(): @@ -6676,7 +6654,7 @@ def add_keys(self, comm=None, worker=None, keys=()): This should not be used in practice and is mostly here for legacy reasons. However, it is sent by workers from time to time. 
""" - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if worker not in parent._workers_dv: return "not found" ws: WorkerState = parent._workers_dv[worker] @@ -6717,7 +6695,7 @@ def update_data( -------- Scheduler.mark_key_in_memory """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() @@ -6748,7 +6726,7 @@ def update_data( self.client_desires_keys(keys=list(who_has), client=client) def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if ts is None: ts = parent._tasks.get(key) elif key is None: @@ -6820,7 +6798,7 @@ def subscribe_worker_status(self, comm=None): return ident def get_processing(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState ts: TaskState if workers is not None: @@ -6835,7 +6813,7 @@ def get_processing(self, comm=None, workers=None): } def get_who_has(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState ts: TaskState if keys is not None: @@ -6852,7 +6830,7 @@ def get_who_has(self, comm=None, keys=None): } def get_has_what(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState ts: TaskState if workers is not None: @@ -6870,7 +6848,7 @@ def get_has_what(self, comm=None, workers=None): } def get_ncores(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState if workers is not None: workers = map(self.coerce_address, workers) @@ -6883,7 +6861,7 @@ def get_ncores(self, comm=None, workers=None): return {w: ws._nthreads for w, ws in parent._workers_dv.items()} async def get_call_stack(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState dts: TaskState if keys is not None: @@ -6914,7 +6892,7 @@ async def get_call_stack(self, comm=None, keys=None): return response def get_nbytes(self, comm=None, keys=None, summary=True): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState with log_errors(): if keys is not None: @@ -6952,7 +6930,7 @@ def run_function(self, stream, function, args=(), kwargs={}, wait=True): return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) def set_metadata(self, comm=None, keys=None, value=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state try: metadata = parent._task_metadata for key in keys[:-1]: @@ -6966,7 +6944,7 @@ def set_metadata(self, comm=None, keys=None, value=None): pdb.set_trace() def get_metadata(self, comm=None, keys=None, default=no_default): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state metadata = parent._task_metadata for key in keys[:-1]: metadata = metadata[key] @@ -6983,7 +6961,7 @@ def set_restrictions(self, comm=None, worker=None): self.tasks[key]._worker_restrictions = set(restrictions) def get_task_status(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state return { key: 
(parent._tasks[key].state if key in parent._tasks else None) for key in keys @@ -7074,7 +7052,7 @@ def transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state recommendations: dict worker_msgs: dict client_msgs: dict @@ -7089,7 +7067,7 @@ def transitions(self, recommendations: dict): This includes feedback from previous transitions and continues until we reach a steady state """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state client_msgs: dict = {} worker_msgs: dict = {} parent._transitions(recommendations, client_msgs, worker_msgs) @@ -7099,7 +7077,9 @@ def story(self, *keys): """Get all transitions that touch one of the input keys""" keys = {key.key if isinstance(key, TaskState) else key for key in keys} return [ - t for t in self.transition_log if t[0] in keys or keys.intersection(t[3]) + t + for t in self.state.transition_log + if t[0] in keys or keys.intersection(t[3]) ] transition_story = story @@ -7110,7 +7090,7 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ts: TaskState try: ts = parent._tasks[key] @@ -7131,7 +7111,7 @@ def reschedule(self, key=None, worker=None): ##################### def add_resources(self, comm=None, worker=None, resources=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState = parent._workers_dv[worker] if resources: ws._resources.update(resources) @@ -7145,7 +7125,7 @@ def add_resources(self, comm=None, worker=None, resources=None): return "OK" def remove_resources(self, worker): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState = parent._workers_dv[worker] for resource, quantity in ws._resources.items(): dr: dict = parent._resources.get(resource, None) @@ -7161,7 +7141,7 @@ def coerce_address(self, addr, resolve=True): Handles strings, tuples, or aliases. """ # XXX how many address-parsing routines do we have? - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if addr in parent._aliases: addr = parent._aliases[addr] if isinstance(addr, tuple): @@ -7183,7 +7163,7 @@ def workers_list(self, workers): Takes a list of worker addresses or hostnames. 
Returns a list of all worker addresses that match """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if workers is None: return list(parent._workers) @@ -7219,7 +7199,7 @@ async def get_profile( stop=None, key=None, ): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if workers is None: workers = parent._workers_dv else: @@ -7253,7 +7233,7 @@ async def get_profile_metadata( stop=None, profile_cycle_interval=None, ): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state dt = profile_cycle_interval or dask.config.get( "distributed.worker.profile.cycle" ) @@ -7296,7 +7276,7 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} async def performance_report(self, comm=None, start=None, last_count=None, code=""): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( @@ -7472,7 +7452,7 @@ def get_events(self, comm=None, topic=None): return valmap(tuple, self.events) async def get_worker_monitor_info(self, recent=False, starts=None): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if starts is None: starts = {} results = await asyncio.gather( @@ -7502,7 +7482,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): lets us avoid this fringe optimization when we have better things to think about. """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state try: if self.status == Status.closed: return @@ -7538,7 +7518,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): raise async def check_worker_ttl(self): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState now = time() for ws in parent._workers_dv.values(): @@ -7553,7 +7533,7 @@ async def check_worker_ttl(self): await self.remove_worker(address=ws._address) def check_idle(self): - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state ws: WorkerState if ( any([ws._processing for ws in parent._workers_dv.values()]) @@ -7587,7 +7567,7 @@ def adaptive_target(self, comm=None, target_duration=None): -------- distributed.deploy.Adaptive """ - parent: SchedulerState = cast(SchedulerState, self) + parent: SchedulerState = self.state if target_duration is None: target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) From 51f489fdd8ecf17b6b752600012841adc7f3750f Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 9 Aug 2021 11:28:19 -0400 Subject: [PATCH 04/10] Works again, with transitions as ccall --- distributed/scheduler.py | 76 ++++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 22 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bf390b1a19..5246f22456 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2181,11 +2181,11 @@ def new_task( # State Transitions # ##################### - # cannot be @ccall with args/kwargs - def _transition(self, key, finish: str, *args, **kwargs): + @ccall + def _transition(self, key, finish: str, args: tuple = None, kwargs: dict = None): """Transition a key from its current state to the finish state - Examples + Examples.stimulus_task_finished -------- >>> self._transition('x', 'waiting') {'x': 'processing'} @@ -2209,6 
+2209,8 @@ def _transition(self, key, finish: str, *args, **kwargs): new_msgs: list dependents: set dependencies: set + args: tuple = args or () + kwargs: dict = kwargs or {} try: recommendations = {} worker_msgs = {} @@ -2228,7 +2230,10 @@ def _transition(self, key, finish: str, *args, **kwargs): start_finish = (start, finish) func = self._transitions_table.get(start_finish) if func is not None: - a: tuple = func(key, *args, **kwargs) + if finish in ["forgotten", "waiting", "released", "processing"]: + a: tuple = func(key) + else: + a: tuple = func(key, args, kwargs) self._transition_counter += 1 recommendations, client_msgs, worker_msgs = a elif "released" not in start_finish: @@ -2325,6 +2330,7 @@ def _transition(self, key, finish: str, *args, **kwargs): pdb.set_trace() raise + @ccall def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): """Process transitions until none are left @@ -2583,6 +2589,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: ws._processing[ts] = total_duration return total_duration + @ccall def transition_waiting_processing(self, key): try: ts: TaskState = self._tasks[key] @@ -2630,8 +2637,9 @@ def transition_waiting_processing(self, key): pdb.set_trace() raise + @ccall def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + self, key, nbytes=None, type=None, typename: str = None, worker=None ): try: ws: WorkerState = self._workers_dv[worker] @@ -2670,16 +2678,18 @@ def transition_waiting_memory( pdb.set_trace() raise + @ccall def transition_processing_memory( self, key, - nbytes=None, - type=None, - typename: str = None, - worker=None, - startstops=None, - **kwargs, + _, + kwargs, ): + nbytes = kwargs.get("nbytes") + type = kwargs.get("type") + typename: str = kwargs.get("typename") + worker = kwargs.get("worker") + startstops = kwargs.get("startstops") ws: WorkerState wws: WorkerState recommendations: dict = {} @@ -2797,6 +2807,7 @@ def transition_processing_memory( pdb.set_trace() raise + @ccall def transition_memory_released(self, key, safe: bint = False): ws: WorkerState try: @@ -2871,6 +2882,7 @@ def transition_memory_released(self, key, safe: bint = False): pdb.set_trace() raise + @ccall def transition_released_erred(self, key): try: ts: TaskState = self._tasks[key] @@ -2916,6 +2928,7 @@ def transition_released_erred(self, key): pdb.set_trace() raise + @ccall def transition_erred_released(self, key): try: ts: TaskState = self._tasks[key] @@ -2960,6 +2973,7 @@ def transition_erred_released(self, key): pdb.set_trace() raise + @ccall def transition_waiting_released(self, key): try: ts: TaskState = self._tasks[key] @@ -2997,6 +3011,7 @@ def transition_waiting_released(self, key): pdb.set_trace() raise + @ccall def transition_processing_released(self, key): try: ts: TaskState = self._tasks[key] @@ -3044,8 +3059,9 @@ def transition_processing_released(self, key): pdb.set_trace() raise + @ccall def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, worker=None, **kwargs + self, key, cause=None, exception=None, traceback=None, worker=None ): ws: WorkerState try: @@ -3123,6 +3139,7 @@ def transition_processing_erred( pdb.set_trace() raise + @ccall def transition_no_worker_released(self, key): try: ts: TaskState = self._tasks[key] @@ -3166,6 +3183,7 @@ def remove_key(self, key): ts._exception_blame = ts._exception = ts._traceback = None self._task_metadata.pop(key, None) + @ccall def transition_memory_forgotten(self, 
key): ws: WorkerState try: @@ -4373,9 +4391,11 @@ async def add_worker( t: tuple = parent._transition( key, "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], + kwargs=dict( + worker=address, + nbytes=nbytes[key], + typename=types[key], + ), ) recommendations, client_msgs, worker_msgs = t parent._transitions( @@ -4803,7 +4823,17 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): ts._metadata.update(kwargs["metadata"]) if ts._state != "released": - r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) + r: tuple = parent._transition( + key, + "memory", + kwargs=dict( + worker=worker, + nbytes=kwargs.get("nbytes"), + type=kwargs.get("type"), + typename=kwargs.get("typename"), + startstops=kwargs.get("startstops"), + ), + ) recommendations, client_msgs, worker_msgs = r if ts._state == "memory": @@ -4849,11 +4879,13 @@ def stimulus_task_erred( r = parent._transition( key, "erred", - cause=key, - exception=exception, - traceback=traceback, - worker=worker, - **kwargs, + kwargs=dict( + cause=key, + exception=exception, + traceback=traceback, + worker=worker, + **kwargs, + ), ) recommendations, client_msgs, worker_msgs = r @@ -7056,7 +7088,7 @@ def transition(self, key, finish: str, *args, **kwargs): recommendations: dict worker_msgs: dict client_msgs: dict - a: tuple = parent._transition(key, finish, *args, **kwargs) + a: tuple = parent._transition(key, finish, args=args, kwargs=kwargs) recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations From 47686f12607ee876376e266fd394a8f1060171c6 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 9 Aug 2021 11:53:33 -0400 Subject: [PATCH 05/10] more ccall --- distributed/scheduler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 5246f22456..af579787f5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2181,7 +2181,7 @@ def new_task( # State Transitions # ##################### - @ccall + @cfunc def _transition(self, key, finish: str, args: tuple = None, kwargs: dict = None): """Transition a key from its current state to the finish state @@ -2370,6 +2370,7 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di for key in keys: self.validate_key(key) + @ccall def transition_released_waiting(self, key): try: ts: TaskState = self._tasks[key] @@ -2425,6 +2426,7 @@ def transition_released_waiting(self, key): pdb.set_trace() raise + @ccall def transition_no_worker_waiting(self, key): try: ts: TaskState = self._tasks[key] @@ -3226,6 +3228,7 @@ def transition_memory_forgotten(self, key): pdb.set_trace() raise + @ccall def transition_released_forgotten(self, key): try: ts: TaskState = self._tasks[key] @@ -4391,6 +4394,7 @@ async def add_worker( t: tuple = parent._transition( key, "memory", + args=None, kwargs=dict( worker=address, nbytes=nbytes[key], @@ -4826,6 +4830,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): r: tuple = parent._transition( key, "memory", + args=None, kwargs=dict( worker=worker, nbytes=kwargs.get("nbytes"), @@ -4879,6 +4884,7 @@ def stimulus_task_erred( r = parent._transition( key, "erred", + args=None, kwargs=dict( cause=key, exception=exception, From a9030e4479d31a261be4166c9fe6ce0daddcc9cb Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 9 Aug 2021 12:04:17 -0400 Subject: [PATCH 06/10] passes a bunch of tests --- distributed/scheduler.py | 52 
+++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index af579787f5..ad85417333 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3361,7 +3361,7 @@ def valid_workers(self, ts: TaskState) -> set: if ts._host_restrictions: # Resolve the alias here rather than early, for the worker # may not be connected when host_restrictions is populated - hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] + hr: list = [self.state.coerce_hostname(h) for h in ts._host_restrictions] # XXX need HostState? sl: list = [] for h in hr: @@ -3448,6 +3448,32 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: else: return (start_time, ws._nbytes) + @ccall + def validate_key(self, key, ts: TaskState = None): + parent: SchedulerState = self + try: + if ts is None: + ts = parent._tasks.get(key) + if ts is None: + logger.debug("Key lost: %s", key) + else: + ts.validate() + try: + func = getattr(self, "validate_" + ts._state.replace("-", "_")) + except AttributeError: + logger.error( + "self.validate_%s not found", ts._state.replace("-", "_") + ) + else: + func(key) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + class Scheduler(ServerNode): """Dynamic distributed task scheduler @@ -5230,29 +5256,7 @@ def validate_erred(self, key): assert not ts._who_has def validate_key(self, key, ts: TaskState = None): - parent: SchedulerState = self.state - try: - if ts is None: - ts = parent._tasks.get(key) - if ts is None: - logger.debug("Key lost: %s", key) - else: - ts.validate() - try: - func = getattr(self, "validate_" + ts._state.replace("-", "_")) - except AttributeError: - logger.error( - "self.validate_%s not found", ts._state.replace("-", "_") - ) - else: - func(key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise + self.state.validate_key(key, ts) def validate_state(self, allow_overlap=False): parent: SchedulerState = self.state From c88725037fc959ee9221317dd018cd3dff71db54 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Mon, 9 Aug 2021 14:39:59 -0400 Subject: [PATCH 07/10] Move validations to SchedulerState --- distributed/scheduler.py | 267 ++++++++++++++-------------- distributed/tests/test_scheduler.py | 46 ++--- 2 files changed, 156 insertions(+), 157 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ad85417333..5d93bc1d01 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2640,9 +2640,11 @@ def transition_waiting_processing(self, key): raise @ccall - def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None - ): + def transition_waiting_memory(self, key, _, kwargs): + nbytes = kwargs.get("nbytes") + type = kwargs.get("type") + typename: str = kwargs.get("typename") + worker = kwargs.get("worker") try: ws: WorkerState = self._workers_dv[worker] ts: TaskState = self._tasks[key] @@ -3062,10 +3064,13 @@ def transition_processing_released(self, key): raise @ccall - def transition_processing_erred( - self, key, cause=None, exception=None, traceback=None, worker=None - ): + def transition_processing_erred(self, key, _, kwargs): + cause = kwargs.get("cause") + exception = kwargs.get("exception") + traceback = kwargs.get("traceback") + worker = kwargs.get("worker") ws: WorkerState + print(key, cause, exception, worker) try: ts: TaskState = self._tasks[key] 
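# A minimal, self-contained sketch of the calling convention these hunks adopt:
# transition handlers keep a uniform (key, args, kwargs) signature so they can be
# declared @ccall/@cfunc, and _transition forwards per-transition values in a plain
# dict that each handler unpacks with kwargs.get().  The names below are
# illustrative stand-ins, not code from this patch.

def example_processing_erred(key, _, kwargs):
    # unpack only what this transition needs, mirroring the handler above
    worker = kwargs.get("worker")
    exception = kwargs.get("exception")
    return {}, {}, {}  # recommendations, client_msgs, worker_msgs

def dispatch(table: dict, key, finish: str, args: tuple = (), kwargs: dict = None):
    kwargs = kwargs or {}
    func = table[finish]
    if finish in ("forgotten", "waiting", "released", "processing"):
        return func(key)  # these transitions take no extra arguments
    return func(key, args, kwargs)

# e.g. dispatch({"erred": example_processing_erred}, "x", "erred", kwargs={"worker": "w1"})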
dts: TaskState @@ -3448,12 +3453,76 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: else: return (start_time, ws._nbytes) - @ccall + ################### + # Task Validation # + ################### + # TODO: could all be @ccall, but called rarely + + def validate_released(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._state == "released" + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in dts._waiters for dts in ts._dependencies]) + assert ts not in self._unrunnable + + def validate_waiting(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert ts not in self._unrunnable + for dts in ts._dependencies: + # We are waiting on a dependency iff it's not stored + assert (not not dts._who_has) != (dts in ts._waiting_on) + assert ts in dts._waiters # XXX even if dts._who_has? + + def validate_processing(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on + assert ws + assert ts in ws._processing + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters + + def validate_memory(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._who_has + assert not ts._processing_on + assert not ts._waiting_on + assert ts not in self._unrunnable + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) + assert ts not in dts._waiting_on + + def validate_no_worker(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts in self._unrunnable + assert not ts._waiting_on + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + + def validate_erred(self, key): + ts: TaskState = self._tasks[key] + assert ts._exception_blame + assert not ts._who_has + def validate_key(self, key, ts: TaskState = None): - parent: SchedulerState = self try: if ts is None: - ts = parent._tasks.get(key) + ts = self._tasks.get(key) if ts is None: logger.debug("Key lost: %s", key) else: @@ -3474,6 +3543,52 @@ def validate_key(self, key, ts: TaskState = None): pdb.set_trace() raise + def validate_state(self, allow_overlap=False): + validate_state(self._tasks, self._workers, self._clients) + + # if not (set(parent._workers_dv) == set(self.stream_comms)): + # raise ValueError("Workers not the same in all collections") + + ws: WorkerState + for w, ws in self._workers_dv.items(): + assert isinstance(w, str), (type(w), w) + assert isinstance(ws, WorkerState), (type(ws), ws) + assert ws._address == w + if not ws._processing: + assert not ws._occupancy + assert ws._address in self._idle_dv + + ts: TaskState + for k, ts in self._tasks.items(): + assert isinstance(ts, TaskState), (type(ts), ts) + assert ts._key == k + self.validate_key(k, ts) + + c: str + cs: ClientState + for c, cs in self._clients.items(): + # client=None is often used in tests... 
+ assert c is None or type(c) == str, (type(c), c) + assert type(cs) == ClientState, (type(cs), cs) + assert cs._client_key == c + + a = {w: ws._nbytes for w, ws in self._workers_dv.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws._has_what) + for w, ws in self._workers_dv.items() + } + assert a == b, (a, b) + + actual_total_occupancy = 0 + for worker, ws in self._workers_dv.items(): + assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 + actual_total_occupancy += ws._occupancy + + assert abs(actual_total_occupancy - self._total_occupancy) < 1e-8, ( + actual_total_occupancy, + self._total_occupancy, + ) + class Scheduler(ServerNode): """Dynamic distributed task scheduler @@ -4434,7 +4549,7 @@ async def add_worker( recommendations = {} for ts in list(parent._unrunnable): - valid: set = self.valid_workers(ts) + valid: set = self.state.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" @@ -5183,127 +5298,8 @@ def client_heartbeat(self, client=None): cs: ClientState = parent._clients[client] cs._last_seen = time() - ################### - # Task Validation # - ################### - - def validate_released(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._state == "released" - assert not ts._waiters - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in parent._unrunnable - - def validate_waiting(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert ts not in parent._unrunnable - for dts in ts._dependencies: - # We are waiting on a dependency iff it's not stored - assert (not not dts._who_has) != (dts in ts._waiting_on) - assert ts in dts._waiters # XXX even if dts._who_has? 
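# A small, hypothetical sketch (names simplified, not from this patch) of the
# delegation pattern replacing the methods removed in this hunk: validation logic
# moves onto the extracted state object, while the Scheduler keeps a thin wrapper
# so existing callers and tests continue to work unchanged.

class ExampleState:
    def __init__(self, tasks: dict):
        self.tasks = tasks

    def validate_key(self, key, ts=None):
        ts = ts if ts is not None else self.tasks.get(key)
        assert ts is not None, f"key lost: {key}"

class ExampleScheduler:
    def __init__(self, tasks: dict):
        self.state = ExampleState(tasks)

    def validate_key(self, key, ts=None):
        # forward to the extracted state object
        return self.state.validate_key(key, ts)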
- - def validate_processing(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - dts: TaskState - assert not ts._waiting_on - ws: WorkerState = ts._processing_on - assert ws - assert ts in ws._processing - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - assert ts in dts._waiters - - def validate_memory(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._who_has - assert not ts._processing_on - assert not ts._waiting_on - assert ts not in parent._unrunnable - for dts in ts._dependents: - assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) - assert ts not in dts._waiting_on - - def validate_no_worker(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts in parent._unrunnable - assert not ts._waiting_on - assert ts in parent._unrunnable - assert not ts._processing_on - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - - def validate_erred(self, key): - parent: SchedulerState = self.state - ts: TaskState = parent._tasks[key] - assert ts._exception_blame - assert not ts._who_has - - def validate_key(self, key, ts: TaskState = None): - self.state.validate_key(key, ts) - - def validate_state(self, allow_overlap=False): - parent: SchedulerState = self.state - validate_state(parent._tasks, parent._workers, parent._clients) - - if not (set(parent._workers_dv) == set(self.stream_comms)): - raise ValueError("Workers not the same in all collections") - - ws: WorkerState - for w, ws in parent._workers_dv.items(): - assert isinstance(w, str), (type(w), w) - assert isinstance(ws, WorkerState), (type(ws), ws) - assert ws._address == w - if not ws._processing: - assert not ws._occupancy - assert ws._address in parent._idle_dv - - ts: TaskState - for k, ts in parent._tasks.items(): - assert isinstance(ts, TaskState), (type(ts), ts) - assert ts._key == k - self.validate_key(k, ts) - - c: str - cs: ClientState - for c, cs in parent._clients.items(): - # client=None is often used in tests... - assert c is None or type(c) == str, (type(c), c) - assert type(cs) == ClientState, (type(cs), cs) - assert cs._client_key == c - - a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in parent._workers_dv.items() - } - assert a == b, (a, b) - - actual_total_occupancy = 0 - for worker, ws in parent._workers_dv.items(): - assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 - actual_total_occupancy += ws._occupancy - - assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( - actual_total_occupancy, - parent._total_occupancy, - ) + def validate_state(self): + self.state.validate_state() ################### # Manage Messages # @@ -6032,7 +6028,7 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the whitelisted workers. 
""" - parent: SchedulerState = self + parent: SchedulerState = self.state with log_errors(): if workers is not None: @@ -6063,7 +6059,7 @@ async def rebalance( return await self._rebalance_move_data(msgs) def _rebalance_find_msgs( - self: SchedulerState, + self, keys: "Optional[Set[Hashable]]", workers: "Iterable[WorkerState]", ) -> "list[tuple[WorkerState, WorkerState, TaskState]]": @@ -6093,7 +6089,7 @@ def _rebalance_find_msgs( - recipient worker - task to be transferred """ - parent: SchedulerState = self + parent: SchedulerState = self.state ts: TaskState ws: WorkerState @@ -7961,6 +7957,7 @@ def decide_worker( return ws +@ccall def validate_task_state(ts: TaskState): """ Validate the given TaskState. @@ -8062,6 +8059,7 @@ def validate_task_state(ts: TaskState): assert ts in ts._processing_on.actors +@ccall def validate_worker_state(ws: WorkerState): ts: TaskState for ts in ws._has_what: @@ -8076,6 +8074,7 @@ def validate_worker_state(ws: WorkerState): assert ts._state in ("memory", "processing") +@ccall def validate_state(tasks, workers, clients): """ Validate a current runtime state diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 536fd3fd45..abf233888e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -715,11 +715,11 @@ async def test_coerce_address(s): assert s.coerce_address(123) == b.address assert s.coerce_address("charlie") == c.address - assert s.coerce_hostname("127.0.0.1") == "127.0.0.1" - assert s.coerce_hostname("alice") == a.ip - assert s.coerce_hostname(123) == b.ip - assert s.coerce_hostname("charlie") == c.ip - assert s.coerce_hostname("jimmy") == "jimmy" + assert s.state.coerce_hostname("127.0.0.1") == "127.0.0.1" + assert s.state.coerce_hostname("alice") == a.ip + assert s.state.coerce_hostname(123) == b.ip + assert s.state.coerce_hostname("charlie") == c.ip + assert s.state.coerce_hostname("jimmy") == "jimmy" assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" await asyncio.gather(a.close(), b.close(), c.close()) @@ -777,11 +777,11 @@ async def test_story(c, s, a, b): f = c.persist(y) await wait([f]) - assert s.transition_log + assert s.state.transition_log story = s.story(x.key) - assert all(line in s.transition_log for line in story) - assert len(story) < len(s.transition_log) + assert all(line in s.state.transition_log for line in story) + assert len(story) < len(s.state.transition_log) assert all(x.key == line[0] or x.key in line[-2] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -1542,13 +1542,13 @@ async def test_dont_recompute_if_persisted(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1579,12 +1579,12 @@ async def test_dont_recompute_if_persisted_3(c, s, a, b): ww = w.persist() await wait(ww) - old = list(s.transition_log) + old = list(s.state.transition_log) www = w.persist() await wait(www) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1631,13 +1631,13 @@ async def test_dont_recompute_if_erred(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) 
== old + assert list(s.state.transition_log) == old @gen_cluster() @@ -3126,17 +3126,17 @@ async def test_computations(c, s, a, b): z = (x - 2).persist() await z - assert len(s.computations) == 2 - assert "add" in str(s.computations[0].groups) - assert "sub" in str(s.computations[1].groups) - assert "sub" not in str(s.computations[0].groups) + assert len(s.state.computations) == 2 + assert "add" in str(s.state.computations[0].groups) + assert "sub" in str(s.state.computations[1].groups) + assert "sub" not in str(s.state.computations[0].groups) - assert isinstance(repr(s.computations[1]), str) - assert "x + 1" in s.computations[1]._repr_html_() + assert isinstance(repr(s.state.computations[1]), str) + assert "x + 1" in s.state.computations[1]._repr_html_() - assert s.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) + assert s.state.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) - assert s.computations[0].states["memory"] == y.npartitions + assert s.state.computations[0].states["memory"] == y.npartitions @gen_cluster(client=True) @@ -3145,7 +3145,7 @@ async def test_computations_futures(c, s, a, b): total = c.submit(sum, futures) await total - [computation] = s.computations + [computation] = s.state.computations assert "sum" in str(computation.groups) assert "inc" in str(computation.groups) From 7767c9c343706101343665bedb207d8ba1526d73 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 10 Aug 2021 14:35:38 -0400 Subject: [PATCH 08/10] update some tests to use SchedulerState attrs --- distributed/tests/test_client.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 8c1e0645ee..c89ac2dd0e 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3831,7 +3831,7 @@ async def test_idempotence(s, a, b): # Submit x = c.submit(inc, 1) await x - log = list(s.transition_log) + log = list(s.state.transition_log) len_single_submit = len(log) # see last assert @@ -3839,29 +3839,29 @@ async def test_idempotence(s, a, b): assert x.key == y.key await y await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 # Error a = c.submit(div, 1, 0) await wait(a) assert a.status == "error" - log = list(s.transition_log) + log = list(s.state.transition_log) b = f.submit(div, 1, 0) assert a.key == b.key await wait(b) await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 - s.transition_log.clear() + s.state.transition_log.clear() # Simultaneous Submit d = c.submit(inc, 2) e = c.submit(inc, 2) await wait([d, e]) - assert len(s.transition_log) == len_single_submit + assert len(s.state.transition_log) == len_single_submit await c.close() await f.close() @@ -4517,7 +4517,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _ in scheduler.state.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") @@ -5482,7 +5482,7 @@ def fib(x): future = c.submit(fib, 8) result = await future assert result == 21 - assert len(s.transition_log) > 50 + assert len(s.state.transition_log) > 50 @gen_cluster(client=True) @@ -6763,7 +6763,7 @@ def test_computation_object_code_dask_compute(client): 
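# The test updates in this hunk and the ones above follow a single pattern:
# bookkeeping attributes such as transition_log and computations are now read
# from scheduler.state instead of the Scheduler object itself, e.g.
# (hypothetical before/after):
#
#     old = list(s.transition_log)        # before SchedulerState was extracted
#     old = list(s.state.transition_log)  # after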
test_function_code = inspect.getsource(test_computation_object_code_dask_compute) def fetch_comp_code(dask_scheduler): - computations = list(dask_scheduler.computations) + computations = list(dask_scheduler.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6784,7 +6784,7 @@ async def test_computation_object_code_dask_persist(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_dask_persist.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6804,7 +6804,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_simple.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6825,7 +6825,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_list_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6847,7 +6847,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_dict_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6867,7 +6867,7 @@ async def test_computation_object_code_client_map(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_map.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6885,7 +6885,7 @@ async def test_computation_object_code_client_compute(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_compute.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 From e15ad89898d3484b30a56d2c54ae9fcdbcc2af1e Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 10 Aug 2021 16:10:51 -0400 Subject: [PATCH 09/10] merge from main --- distributed/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0af6f02b04..235ed0fbf1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3119,7 +3119,7 @@ def transition_processing_erred(self, key, _, kwargs): exception = kwargs.get("exception") exception_text: str = kwargs.get("exception_text") traceback = kwargs.get("traceback") - traceback_text: str = kwargs.get("traceback_test") + traceback_text: str = kwargs.get("traceback_text") worker = kwargs.get("worker") ws: WorkerState try: From 956a0f16ac16cb34d44597a0c67248f7fd62bc98 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 10 Aug 2021 16:11:21 -0400 Subject: [PATCH 10/10] add back attributes (maybe temporarily) --- distributed/scheduler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 235ed0fbf1..19ed88feda 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4651,6 +4651,15 @@ async def add_nanny(self, comm): } return msg + def get_task_duration(self, ts: TaskState, default: double = -1) -> double: + return 
self.state.get_task_duration(ts, default) + + def get_comm_cost(self, *args, **kwargs): + return self.state.get_comm_cost(*args, **kwargs) + + def check_idle_saturated(self, *args, **kwargs): + return self.state.check_idle_saturated(*args, **kwargs) + def update_graph_hlg( self, client=None,