diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 59c1d4a41c..714e109c77 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -35,6 +35,7 @@ from sortedcontainers import SortedDict, SortedSet from tlz import ( compose, + concat, first, groupby, merge, @@ -1923,6 +1924,7 @@ def _task_key_or_none(task: TaskState): return task._key if task is not None else None +@final @cclass class SchedulerState: """Underlying task state of dynamic scheduler @@ -1981,6 +1983,7 @@ class SchedulerState: _idle: "SortedDict[str, WorkerState]" _idle_dv: dict # dict[str, WorkerState] _n_tasks: Py_ssize_t + _plugins: list _resources: dict _saturated: set # set[WorkerState] _running: set # set[WorkerState] @@ -1991,6 +1994,7 @@ class SchedulerState: _replicated_tasks: set _total_nthreads: Py_ssize_t _total_occupancy: double + _transition_log: deque _transitions_table: dict _unknown_durations: dict _unrunnable: set @@ -2047,6 +2051,9 @@ def __init__( self._task_metadata = {} self._total_nthreads = 0 self._total_occupancy = 0 + self._transition_log = deque( + maxlen=dask.config.get("distributed.scheduler.transition-log-length") + ) self._transitions_table = { ("released", "waiting"): self.transition_released_waiting, ("waiting", "released"): self.transition_waiting_released, @@ -2132,6 +2139,14 @@ def idle(self): def n_tasks(self): return self._n_tasks + @property + def plugins(self): + return self._plugins + + @plugins.setter + def plugins(self, val): + self._plugins = val + @property def resources(self): return self._resources @@ -2176,6 +2191,10 @@ def total_occupancy(self): def total_occupancy(self, v: double): self._total_occupancy = v + @property + def transition_log(self): + return self._transition_log + @property def transition_counter(self): return self._transition_counter @@ -2264,10 +2283,11 @@ def new_task( # State Transitions # ##################### - def _transition(self, key, finish: str, *args, **kwargs): + @cfunc + def _transition(self, key, finish: str, args: tuple = None, kwargs: dict = None): """Transition a key from its current state to the finish state - Examples + Examples.stimulus_task_finished -------- >>> self._transition('x', 'waiting') {'x': 'processing'} @@ -2280,7 +2300,6 @@ def _transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions : transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) ts: TaskState start: str start_finish: tuple @@ -2292,12 +2311,14 @@ def _transition(self, key, finish: str, *args, **kwargs): new_msgs: list dependents: set dependencies: set + args: tuple = args or () + kwargs: dict = kwargs or {} try: recommendations = {} worker_msgs = {} client_msgs = {} - ts = parent._tasks.get(key) # type: ignore + ts = self._tasks.get(key) # type: ignore if ts is None: return recommendations, client_msgs, worker_msgs start = ts._state @@ -2362,12 +2383,8 @@ def _transition(self, key, finish: str, *args, **kwargs): raise RuntimeError("Impossible transition from %r to %r" % start_finish) finish2 = ts._state - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) - scheduler.transition_log.append( - (key, start, finish2, recommendations, time()) - ) - if parent._validate: + self._transition_log.append((key, start, finish2, recommendations, time())) + if self._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). 
Consequence: %s", key, @@ -2381,17 +2398,17 @@ def _transition(self, key, finish: str, *args, **kwargs): if ts._state == "forgotten": ts._dependents = dependents ts._dependencies = dependencies - parent._tasks[ts._key] = ts + self._tasks[ts._key] = ts for plugin in list(self.plugins.values()): try: plugin.transition(key, start, finish2, *args, **kwargs) except Exception: logger.info("Plugin failed with exception", exc_info=True) if ts._state == "forgotten": - del parent._tasks[ts._key] + del self._tasks[ts._key] tg: TaskGroup = ts._group - if ts._state == "forgotten" and tg._name in parent._task_groups: + if ts._state == "forgotten" and tg._name in self._task_groups: # Remove TaskGroup if all tasks are in the forgotten state all_forgotten: bint = True for s in ALL_TASK_STATES: @@ -2400,7 +2417,7 @@ def _transition(self, key, finish: str, *args, **kwargs): break if all_forgotten: ts._prefix._groups.remove(tg) - del parent._task_groups[tg._name] + del self._task_groups[tg._name] return recommendations, client_msgs, worker_msgs except Exception: @@ -2411,6 +2428,7 @@ def _transition(self, key, finish: str, *args, **kwargs): pdb.set_trace() raise + @ccall def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: dict): """Process transitions until none are left @@ -2447,11 +2465,10 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di worker_msgs[w] = new_msgs if self._validate: - # FIXME downcast antipattern - scheduler = pep484_cast(Scheduler, self) for key in keys: scheduler.validate_key(key) + @ccall def transition_released_waiting(self, key): try: ts: TaskState = self._tasks[key] @@ -2507,6 +2524,7 @@ def transition_released_waiting(self, key): pdb.set_trace() raise + @ccall def transition_no_worker_waiting(self, key): try: ts: TaskState = self._tasks[key] @@ -2717,6 +2735,7 @@ def set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> double: return total_duration + @ccall def transition_waiting_processing(self, key): try: ts: TaskState = self._tasks[key] @@ -2762,9 +2781,12 @@ def transition_waiting_processing(self, key): pdb.set_trace() raise - def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs - ): + @ccall + def transition_waiting_memory(self, key, _, kwargs): + nbytes = kwargs.get("nbytes") + type = kwargs.get("type") + typename: str = kwargs.get("typename") + worker = kwargs.get("worker") try: ws: WorkerState = self._workers_dv[worker] ts: TaskState = self._tasks[key] @@ -2802,16 +2824,18 @@ def transition_waiting_memory( pdb.set_trace() raise + @ccall def transition_processing_memory( self, key, - nbytes=None, - type=None, - typename: str = None, - worker=None, - startstops=None, - **kwargs, + _, + kwargs, ): + nbytes = kwargs.get("nbytes") + type = kwargs.get("type") + typename: str = kwargs.get("typename") + worker = kwargs.get("worker") + startstops = kwargs.get("startstops") ws: WorkerState wws: WorkerState recommendations: dict = {} @@ -2898,6 +2922,7 @@ def transition_processing_memory( pdb.set_trace() raise + @ccall def transition_memory_released(self, key, safe: bint = False): ws: WorkerState try: @@ -2968,6 +2993,7 @@ def transition_memory_released(self, key, safe: bint = False): pdb.set_trace() raise + @ccall def transition_released_erred(self, key): try: ts: TaskState = self._tasks[key] @@ -3013,6 +3039,7 @@ def transition_released_erred(self, key): pdb.set_trace() raise + @ccall def transition_erred_released(self, key): try: ts: TaskState = 
self._tasks[key] @@ -3061,6 +3088,7 @@ def transition_erred_released(self, key): pdb.set_trace() raise + @ccall def transition_waiting_released(self, key): try: ts: TaskState = self._tasks[key] @@ -3098,6 +3126,7 @@ def transition_waiting_released(self, key): pdb.set_trace() raise + @ccall def transition_processing_released(self, key): try: ts: TaskState = self._tasks[key] @@ -3149,17 +3178,14 @@ def transition_processing_released(self, key): pdb.set_trace() raise - def transition_processing_erred( - self, - key: str, - cause: str = None, - exception=None, - traceback=None, - exception_text: str = None, - traceback_text: str = None, - worker: str = None, - **kwargs, - ): + @ccall + def transition_processing_erred(self, key, _, kwargs): + cause = kwargs.get("cause") + exception = kwargs.get("exception") + exception_text: str = kwargs.get("exception_text") + traceback = kwargs.get("traceback") + traceback_text: str = kwargs.get("traceback_text") + worker = kwargs.get("worker") ws: WorkerState try: ts: TaskState = self._tasks[key] @@ -3238,6 +3264,7 @@ def transition_processing_erred( pdb.set_trace() raise + @ccall def transition_no_worker_released(self, key): try: ts: TaskState = self._tasks[key] @@ -3281,6 +3308,7 @@ def remove_key(self, key): ts._exception_blame = ts._exception = ts._traceback = None self._task_metadata.pop(key, None) + @ccall def transition_memory_forgotten(self, key): ws: WorkerState try: @@ -3323,6 +3351,7 @@ def transition_memory_forgotten(self, key): pdb.set_trace() raise + @ccall def transition_released_forgotten(self, key): try: ts: TaskState = self._tasks[key] @@ -3459,7 +3488,7 @@ def valid_workers(self, ts: TaskState) -> set: # set[WorkerState] | None if ts._host_restrictions: # Resolve the alias here rather than early, for the worker # may not be connected when host_restrictions is populated - hr: list = [self.coerce_hostname(h) for h in ts._host_restrictions] + hr: list = [self.state.coerce_hostname(h) for h in ts._host_restrictions] # XXX need HostState? sl: list = [] for h in hr: @@ -3602,8 +3631,145 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): for ts in ws._processing: steal.recalculate_cost(ts) + ################### + # Task Validation # + ################### + # TODO: could all be @ccall, but called rarely + + def validate_released(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._state == "released" + assert not ts._waiters + assert not ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert not any([ts in dts._waiters for dts in ts._dependencies]) + assert ts not in self._unrunnable + + def validate_waiting(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._waiting_on + assert not ts._who_has + assert not ts._processing_on + assert ts not in self._unrunnable + for dts in ts._dependencies: + # We are waiting on a dependency iff it's not stored + assert (not not dts._who_has) != (dts in ts._waiting_on) + assert ts in dts._waiters # XXX even if dts._who_has? 
+ + def validate_processing(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert not ts._waiting_on + ws: WorkerState = ts._processing_on + assert ws + assert ts in ws._processing + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + assert ts in dts._waiters + + def validate_memory(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts._who_has + assert bool(ts in self._replicated_tasks) == (len(ts._who_has) > 1) + assert not ts._processing_on + assert not ts._waiting_on + assert ts not in self._unrunnable + for dts in ts._dependents: + assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) + assert ts not in dts._waiting_on + + def validate_no_worker(self, key): + ts: TaskState = self._tasks[key] + dts: TaskState + assert ts in self._unrunnable + assert not ts._waiting_on + assert not ts._processing_on + assert not ts._who_has + for dts in ts._dependencies: + assert dts._who_has + + def validate_erred(self, key): + ts: TaskState = self._tasks[key] + assert ts._exception_blame + assert not ts._who_has + + def validate_key(self, key, ts: TaskState = None): + try: + if ts is None: + ts = self._tasks.get(key) + if ts is None: + logger.debug("Key lost: %s", key) + else: + ts.validate() + try: + func = getattr(self, "validate_" + ts._state.replace("-", "_")) + except AttributeError: + logger.error( + "self.validate_%s not found", ts._state.replace("-", "_") + ) + else: + func(key) + except Exception as e: + logger.exception(e) + if LOG_PDB: + import pdb + + pdb.set_trace() + raise + + def validate_state(self, allow_overlap=False): + validate_state(self._tasks, self._workers, self._clients) + + # if not (set(self.state._workers_dv) == set(self.stream_comms)): + # raise ValueError("Workers not the same in all collections") + + ws: WorkerState + for w, ws in self._workers_dv.items(): + assert isinstance(w, str), (type(w), w) + assert isinstance(ws, WorkerState), (type(ws), ws) + assert ws._address == w + if not ws._processing: + assert not ws._occupancy + assert ws._address in self._idle_dv + + ts: TaskState + for k, ts in self._tasks.items(): + assert isinstance(ts, TaskState), (type(ts), ts) + assert ts._key == k + self.validate_key(k, ts) + + c: str + cs: ClientState + for c, cs in self._clients.items(): + # client=None is often used in tests... + assert c is None or type(c) == str, (type(c), c) + assert type(cs) == ClientState, (type(cs), cs) + assert cs._client_key == c + + a = {w: ws._nbytes for w, ws in self._workers_dv.items()} + b = { + w: sum(ts.get_nbytes() for ts in ws._has_what) + for w, ws in self._workers_dv.items() + } + assert a == b, (a, b) + + actual_total_occupancy = 0 + for worker, ws in self._workers_dv.items(): + assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 + actual_total_occupancy += ws._occupancy + + assert abs(actual_total_occupancy - self._total_occupancy) < 1e-8, ( + actual_total_occupancy, + self._total_occupancy, + ) -class Scheduler(SchedulerState, ServerNode): + +class Scheduler(ServerNode): """Dynamic distributed task scheduler The scheduler tracks the current state of workers, data, and computations. 
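# Illustrative note (not part of the patch): the class statement above is the
# heart of this refactor. ``Scheduler`` no longer inherits from
# ``SchedulerState``; it owns one as ``self.state`` and re-exposes selected
# attributes through thin forwarding properties (see the "state properties"
# section added further down). A toy sketch of that pattern, using
# hypothetical names and nothing beyond the standard library:

class ToyState:
    def __init__(self):
        self.tasks = {}      # key -> TaskState-like object
        self.workers = {}    # address -> WorkerState-like object


class ToyScheduler:
    def __init__(self):
        self.state = ToyState()   # composition instead of inheritance

    @property
    def tasks(self):
        # Reads are forwarded to the state object, like Scheduler.tasks below.
        return self.state.tasks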
@@ -3761,18 +3927,19 @@ def __init__( http_server_modules = dask.config.get("distributed.scheduler.http.routes") show_dashboard = dashboard or (dashboard is None and dashboard_address) + missing_bokeh = False # install vanilla route if show_dashboard but bokeh is not installed if show_dashboard: try: import distributed.dashboard.scheduler except ImportError: - show_dashboard = False + missing_bokeh = True http_server_modules.append("distributed.http.scheduler.missing_bokeh") routes = get_handlers( server=self, modules=http_server_modules, prefix=http_prefix ) self.start_http_server(routes, dashboard_address, default_port=8787) - if show_dashboard: + if show_dashboard and not missing_bokeh: distributed.dashboard.scheduler.connect( self.http_application, self.http_server, self, prefix=http_prefix ) @@ -3865,18 +4032,15 @@ def __init__( resources = {} aliases = {} - self._task_state_collections = [unrunnable] + self._task_state_collections: list = [unrunnable] - self._worker_collections = [ + self._worker_collections: list = [ workers, host_info, resources, aliases, ] - self.transition_log = deque( - maxlen=dask.config.get("distributed.scheduler.transition-log-length") - ) self.log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) @@ -3890,7 +4054,7 @@ def __init__( self.worker_plugins = {} self.nanny_plugins = {} - worker_handlers = { + worker_handlers: dict = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, "release-worker-data": self.release_worker_data, @@ -3903,7 +4067,7 @@ def __init__( "worker-status-change": self.handle_worker_status_change, } - client_handlers = { + client_handlers: dict = { "update-graph": self.update_graph, "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, @@ -3973,9 +4137,7 @@ def __init__( } connection_limit = get_fileno_limit() / 2 - - super().__init__( - # Arguments to SchedulerState + self.state: SchedulerState = SchedulerState( aliases=aliases, clients=clients, workers=workers, @@ -3985,7 +4147,8 @@ def __init__( unrunnable=unrunnable, validate=validate, plugins=plugins, - # Arguments to ServerNode + ) + super().__init__( handlers=self.handlers, stream_handlers=merge(worker_handlers, client_handlers), io_loop=self.loop, @@ -3994,7 +4157,7 @@ def __init__( connection_args=self.connection_args, **kwargs, ) - + self.plugins = list(plugins) if self.worker_ttl: pc = PeriodicCallback(self.check_worker_ttl, self.worker_ttl * 1000) self.periodic_callbacks["worker-ttl"] = pc @@ -4007,6 +4170,18 @@ def __init__( extensions = list(DEFAULT_EXTENSIONS) if dask.config.get("distributed.scheduler.work-stealing"): extensions.append(WorkStealing) + + # Set up ServerNode stuff + super().__init__( + handlers=self.handlers, + stream_handlers=merge(worker_handlers, client_handlers), + io_loop=self.loop, + connection_limit=connection_limit, + deserialize=False, + connection_args=self.connection_args, + **kwargs, + ) + for ext in extensions: ext(self) @@ -4015,31 +4190,136 @@ def __init__( self.rpc.allow_offload = False self.status = Status.undefined + #################### + # state properties # + #################### + + @property + def aliases(self): + return self.state.aliases + + @property + def bandwidth(self): + return self.state.bandwidth + + @property + def clients(self): + return self.state.clients + + @property + def extensions(self): + return self.state.extensions + + @property + def host_info(self): + return self.state.host_info + + @property + def idle(self): + 
return self.state.idle + + @property + def n_tasks(self): + return self.state.n_tasks + + @property + def plugins(self): + return self.state.plugins + + @plugins.setter + def plugins(self, val): + self.state.plugins = val + + @property + def resources(self): + return self.state.resources + + @property + def saturated(self): + return self.state.saturated + + @property + def tasks(self): + return self.state.tasks + + @property + def task_groups(self): + return self.state.task_groups + + @property + def task_prefixes(self): + return self.state.task_prefixes + + @property + def task_metadata(self): + return self.state.task_metadata + + @property + def total_nthreads(self): + return self.state.total_nthreads + + @property + def total_occupancy(self): + return self.state.total_occupancy + + @total_occupancy.setter + def total_occupancy(self, v: double): + self.state.total_occupancy = v + + @property + def transition_counter(self): + return self.state.transition_counter + + @property + def unknown_durations(self): + return self.state.unknown_durations + + @property + def unrunnable(self): + return self.state.unrunnable + + @property + def validate(self): + return self.state.validate + + @validate.setter + def validate(self, v: bint): + self.state.validate = v + + @property + def workers(self): + return self.state.workers + + @property + def memory(self) -> MemoryState: + return MemoryState.sum(*(w.memory for w in self.state.workers.values())) + + @property + def __pdict__(self): + return self.state.__pdict__ + ################## # Administration # ################## def __repr__(self): - parent: SchedulerState = cast(SchedulerState, self) return ( f"" + f"workers: {len(self.state._workers_dv)}, " + f"cores: {self.state._total_nthreads}, " + f"tasks: {len(self.state._tasks)}>" ) def _repr_html_(self): - parent: SchedulerState = cast(SchedulerState, self) return get_template("scheduler.html.j2").render( address=self.address, - workers=parent._workers_dv, - threads=parent._total_nthreads, - tasks=parent._tasks, + workers=self.state._workers_dv, + threads=self.state._total_nthreads, + tasks=self.state._tasks, ) def identity(self, comm=None): """Basic information about ourselves and our cluster""" - parent: SchedulerState = cast(SchedulerState, self) d = { "type": type(self).__name__, "id": str(self.id), @@ -4048,7 +4328,7 @@ def identity(self, comm=None): "started": self.time_started, "workers": { worker.address: worker.identity() - for worker in parent._workers_dv.values() + for worker in self.state._workers_dv.values() }, } return d @@ -4096,8 +4376,7 @@ def get_worker_service_addr(self, worker, service_name, protocol=False): Whether or not to include a full address with protocol (True) or just a (host, port) pair """ - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state._workers_dv[worker] port = ws._services.get(service_name) if port is None: return None @@ -4180,7 +4459,6 @@ async def close(self, comm=None, fast=False, close_workers=False): -------- Scheduler.cleanup """ - parent: SchedulerState = cast(SchedulerState, self) if self.status in (Status.closing, Status.closed): await self.finished() return @@ -4194,13 +4472,13 @@ async def close(self, comm=None, fast=False, close_workers=False): if close_workers: await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in parent._workers_dv: + for worker in self.state._workers_dv: # Report would require the worker to unregister with the # currently 
closing scheduler. This is not necessary and might # delay shutdown of the worker unnecessarily self.worker_send(worker, {"op": "close", "report": False}) for i in range(20): # wait a second for send signals to clear - if parent._workers_dv: + if self.state._workers_dv: await asyncio.sleep(0.05) else: break @@ -4215,7 +4493,7 @@ async def close(self, comm=None, fast=False, close_workers=False): self.stop_services() - for ext in parent._extensions.values(): + for ext in self.state._extensions.values(): with suppress(AttributeError): ext.teardown() logger.info("Scheduler closing all comms") @@ -4273,10 +4551,9 @@ def heartbeat_worker( metrics: dict, executing: dict = None, ): - parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) address = normalize_address(address) - ws: WorkerState = parent._workers_dv.get(address) # type: ignore + ws: WorkerState = self.state._workers_dv.get(address) # type: ignore if ws is None: return {"status": "missing"} @@ -4284,12 +4561,12 @@ def heartbeat_worker( local_now = time() host_info = host_info or {} - dh: dict = parent._host_info.setdefault(host, {}) + dh: dict = self.state._host_info.setdefault(host, {}) dh["last-seen"] = local_now - frac = 1 / len(parent._workers_dv) - parent._bandwidth = ( - parent._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac + frac = 1 / len(self.state._workers_dv) + self.state._bandwidth = ( + self.state._bandwidth * (1 - frac) + metrics["bandwidth"]["total"] * frac ) for other, (bw, count) in metrics["bandwidth"]["workers"].items(): if (address, other) not in self.bandwidth_workers: @@ -4311,16 +4588,16 @@ def heartbeat_worker( ws._last_seen = local_now if executing is not None: ws._executing = { - parent._tasks[key]: duration - for key, duration in executing.items() - if key in parent._tasks + self.state._tasks[key]: duration for key, duration in executing.items() } ws._metrics = metrics # Calculate RSS - dask keys, separating "old" and "new" usage # See MemoryState for details - max_memory_unmanaged_old_hist_age = local_now - parent.MEMORY_RECENT_TO_OLD_TIME + max_memory_unmanaged_old_hist_age = ( + local_now - self.state.MEMORY_RECENT_TO_OLD_TIME + ) memory_unmanaged_old = ws._memory_unmanaged_old while ws._memory_other_history: timestamp, size = ws._memory_other_history[0] @@ -4350,7 +4627,7 @@ def heartbeat_worker( ws._memory_unmanaged_old = size if host_info: - dh = parent._host_info.setdefault(host, {}) + dh = self.state._host_info.setdefault(host, {}) dh.update(host_info) if now: @@ -4364,7 +4641,7 @@ def heartbeat_worker( return { "status": "OK", "time": local_now, - "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), + "heartbeat-interval": heartbeat_interval(len(self.state._workers_dv)), } async def add_worker( @@ -4392,16 +4669,16 @@ async def add_worker( extra=None, ): """Add a new worker to the cluster""" - parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): address = self.coerce_address(address, resolve_address) address = normalize_address(address) host = get_address_host(address) - if address in parent._workers_dv: + if address in self.state._workers_dv: raise ValueError("Worker already exists %s" % address) - if name in parent._aliases: + if name in self.state._aliases: logger.warning( "Worker tried to connect with a duplicate name: %s", name ) @@ -4418,7 +4695,7 @@ async def add_worker( self.log_event("all", {"action": "add-worker", "worker": address}) ws: WorkerState - parent._workers[address] = ws = 
WorkerState( + self.state._workers[address] = ws = WorkerState( address=address, status=Status.lookup[status], # type: ignore pid=pid, @@ -4432,11 +4709,11 @@ async def add_worker( extra=extra, ) if ws._status == Status.running: - parent._running.add(ws) + self.state._running.add(ws) - dh: dict = parent._host_info.get(host) # type: ignore + dh: dict = self.state._host_info.get(host) # type: ignore if dh is None: - parent._host_info[host] = dh = {} + self.state._host_info[host] = dh = {} dh_addresses: set = dh.get("addresses") # type: ignore if dh_addresses is None: @@ -4446,8 +4723,8 @@ async def add_worker( dh_addresses.add(address) dh["nthreads"] += nthreads - parent._total_nthreads += nthreads - parent._aliases[name] = address + self.state._total_nthreads += nthreads + self.state._aliases[name] = address self.heartbeat_worker( address=address, @@ -4458,9 +4735,9 @@ async def add_worker( metrics=metrics, ) - # Do not need to adjust parent._total_occupancy as self.occupancy[ws] cannot + # Do not need to adjust self.state._total_occupancy as self.occupancy[ws] cannot # exist before this. - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) # for key in keys: # TODO # self.mark_key_in_memory(key, [address]) @@ -4468,7 +4745,7 @@ async def add_worker( self.stream_comms[address] = BatchedSend(interval="5ms", loop=self.loop) if ws._nthreads > len(ws._processing): - parent._idle[ws._address] = ws + self.state._idle[ws._address] = ws for plugin in list(self.plugins.values()): try: @@ -4485,20 +4762,23 @@ async def add_worker( assert isinstance(nbytes, dict) already_released_keys = [] for key in nbytes: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state._tasks.get(key) # type: ignore if ts is not None and ts.state != "released": if ts.state == "memory": self.add_keys(worker=address, keys=[key]) else: - t: tuple = parent._transition( + t: tuple = self.state._transition( key, "memory", - worker=address, - nbytes=nbytes[key], - typename=types[key], + args=None, + kwargs=dict( + worker=address, + nbytes=nbytes[key], + typename=types[key], + ), ) recommendations, client_msgs, worker_msgs = t - parent._transitions( + self.state._transitions( recommendations, client_msgs, worker_msgs ) recommendations = {} @@ -4516,13 +4796,13 @@ async def add_worker( ) if ws._status == Status.running: - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) + for ts in self.state._unrunnable: + valid: set = self.state.valid_workers(ts) if valid is None or ws in valid: recommendations[ts._key] = "waiting" if recommendations: - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) @@ -4531,7 +4811,7 @@ async def add_worker( msg = { "status": "OK", "time": time(), - "heartbeat-interval": heartbeat_interval(len(parent._workers_dv)), + "heartbeat-interval": heartbeat_interval(len(self.state._workers_dv)), "worker-plugins": self.worker_plugins, } @@ -4539,10 +4819,10 @@ async def add_worker( version_warning = version_module.error_message( version_module.get_versions(), merge( - {w: ws._versions for w, ws in parent._workers_dv.items()}, + {w: ws._versions for w, ws in self.state._workers_dv.items()}, { c: cs._versions - for c, cs in parent._clients.items() + for c, cs in self.state._clients.items() if cs._versions }, ), @@ -4563,6 +4843,15 @@ async def add_nanny(self, comm): } return msg + def get_task_duration(self, ts: TaskState, default: double 
= -1) -> double: + return self.state.get_task_duration(ts, default) + + def get_comm_cost(self, *args, **kwargs): + return self.state.get_comm_cost(*args, **kwargs) + + def check_idle_saturated(self, *args, **kwargs): + return self.state.check_idle_saturated(*args, **kwargs) + def update_graph_hlg( self, client=None, @@ -4643,7 +4932,7 @@ def update_graph( This happens whenever the Client calls submit, map, get, or compute. """ - parent: SchedulerState = cast(SchedulerState, self) + start = time() fifo_timeout = parse_timedelta(fifo_timeout) keys = set(keys) @@ -4659,12 +4948,12 @@ def update_graph( dependencies = dependencies or {} - if parent._total_occupancy > 1e-9 and parent._computations: + if self.state._total_occupancy > 1e-9 and self.state._computations: # Still working on something. Assign new tasks to same computation - computation = cast(Computation, parent._computations[-1]) + computation = cast(Computation, self.state._computations[-1]) else: computation = Computation() - parent._computations.append(computation) + self.state._computations.append(computation) if code and code not in computation._code: # add new code blocks computation._code.add(code) @@ -4674,7 +4963,7 @@ def update_graph( n = len(tasks) for k, deps in list(dependencies.items()): if any( - dep not in parent._tasks and dep not in tasks for dep in deps + dep not in self.state._tasks and dep not in tasks for dep in deps ): # bad key logger.info("User asked for computation on lost data, %s", k) del tasks[k] @@ -4688,8 +4977,8 @@ def update_graph( ts: TaskState already_in_memory = set() # tasks that are already done for k, v in dependencies.items(): - if v and k in parent._tasks: - ts = parent._tasks[k] + if v and k in self.state._tasks: + ts = self.state._tasks[k] if ts._state in ("memory", "erred"): already_in_memory.add(k) @@ -4700,7 +4989,7 @@ def update_graph( done = set(already_in_memory) while stack: # remove unnecessary dependencies key = stack.pop() - ts = parent._tasks[key] + ts = self.state._tasks[key] try: deps = dependencies[key] except KeyError: @@ -4711,7 +5000,7 @@ def update_graph( else: child_deps = self.dependencies[dep] if all(d in done for d in child_deps): - if dep in parent._tasks and dep not in done: + if dep in self.state._tasks and dep not in done: done.add(dep) stack.append(dep) @@ -4728,9 +5017,9 @@ def update_graph( if k in touched_keys: continue # XXX Have a method get_task_state(self, k) ? - ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is None: - ts = parent.new_task( + ts = self.state.new_task( k, tasks.get(k), "released", computation=computation ) elif not ts._run_spec: @@ -4744,11 +5033,11 @@ def update_graph( # Add dependencies for key, deps in dependencies.items(): - ts = parent._tasks.get(key) + ts = self.state._tasks.get(key) if ts is None or ts._dependencies: continue for dep in deps: - dts = parent._tasks[dep] + dts = self.state._tasks[dep] ts.add_dependency(dts) # Compute priorities @@ -4784,7 +5073,7 @@ def update_graph( for k, v in kv.items(): # Tasks might have been culled, in which case # we have nothing to annotate. 
- ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is not None: ts._annotations[a] = v @@ -4792,7 +5081,7 @@ def update_graph( if actors is True: actors = list(keys) for actor in actors or []: - ts = parent._tasks[actor] + ts = self.state._tasks[actor] ts._actor = True priority = priority or dask.order.order( @@ -4800,7 +5089,7 @@ def update_graph( ) # TODO: define order wrt old graph if submitting_task: # sub-tasks get better priority than parent tasks - ts = parent._tasks.get(submitting_task) + ts = self.state._tasks.get(submitting_task) if ts is not None: generation = ts._priority[0] - 0.01 else: # super-task already cleaned up @@ -4813,7 +5102,7 @@ def update_graph( generation = self.generation for key in set(priority) & touched_keys: - ts = parent._tasks[key] + ts = self.state._tasks[key] if ts._priority is None: ts._priority = (-(user_priority.get(key, 0)), generation, priority[key]) @@ -4829,7 +5118,7 @@ def update_graph( for k, v in restrictions.items(): if v is None: continue - ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is None: continue ts._host_restrictions = set() @@ -4848,7 +5137,7 @@ def update_graph( if loose_restrictions: for k in loose_restrictions: - ts = parent._tasks[k] + ts = self.state._tasks[k] ts._loose_restrictions = True if resources: @@ -4856,7 +5145,7 @@ def update_graph( if v is None: continue assert isinstance(v, dict) - ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is None: continue ts._resource_restrictions = v @@ -4864,7 +5153,7 @@ def update_graph( if retries: for k, v in retries.items(): assert isinstance(v, int) - ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is None: continue ts._retries = v @@ -4914,15 +5203,14 @@ def update_graph( def stimulus_task_finished(self, key=None, worker=None, **kwargs): """Mark that a task has finished execution on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) recommendations: dict = {} client_msgs: dict = {} worker_msgs: dict = {} - ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) + ws: WorkerState = self.state._workers_dv[worker] + ts: TaskState = self.state._tasks.get(key) if ts is None or ts._state == "released": logger.debug( "Received already computed task, worker: %s, state: %s" @@ -4943,7 +5231,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): self.add_keys(worker=worker, keys=[key]) else: ts._metadata.update(kwargs["metadata"]) - r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) + r: tuple = self.state._transition(key, "memory", worker=worker, **kwargs) recommendations, client_msgs, worker_msgs = r if ts._state == "memory": @@ -4954,18 +5242,19 @@ def stimulus_task_erred( self, key=None, worker=None, exception=None, traceback=None, **kwargs ): """Mark that a task has erred on a particular worker""" - parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state._tasks.get(key) if ts is None or ts._state != "processing": return {}, {}, {} - if ts._retries > 0: - ts._retries -= 1 - return parent._transition(key, "waiting") + retries: Py_ssize_t = ts._retries + if retries > 0: + retries -= 1 + ts._retries = retries + return self.state._transition(key, "waiting") else: - return parent._transition( + return self.state._transition( key, "erred", cause=key, @@ -4976,7 
+5265,7 @@ def stimulus_task_erred( ) def stimulus_retry(self, comm=None, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) + logger.info("Client %s requests to retry %d keys", client, len(keys)) if client: self.log_event(client, {"action": "retry", "count": len(keys)}) @@ -4989,7 +5278,7 @@ def stimulus_retry(self, comm=None, keys=None, client=None): while stack: key = stack.pop() seen.add(key) - ts = parent._tasks[key] + ts = self.state._tasks[key] erred_deps = [dts._key for dts in ts._dependencies if dts._state == "erred"] if erred_deps: stack.extend(erred_deps) @@ -4999,9 +5288,9 @@ def stimulus_retry(self, comm=None, keys=None, client=None): recommendations: dict = {key: "waiting" for key in roots} self.transitions(recommendations) - if parent._validate: + if self.state._validate: for key in seen: - assert not parent._tasks[key].exception_blame + assert not self.state._tasks[key].exception_blame return tuple(seen) @@ -5013,19 +5302,19 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): appears to be unresponsive. This may send its tasks back to a released state. """ - parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): if self.status == Status.closed: return address = self.coerce_address(address) - if address not in parent._workers_dv: + if address not in self.state._workers_dv: return "already-removed" host = get_address_host(address) - ws: WorkerState = parent._workers_dv[address] + ws: WorkerState = self.state._workers_dv[address] event_msg = { "action": "remove-worker", @@ -5042,23 +5331,23 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): self.remove_resources(address) - dh: dict = parent._host_info[host] + dh: dict = self.state._host_info[host] dh_addresses: set = dh["addresses"] dh_addresses.remove(address) dh["nthreads"] -= ws._nthreads - parent._total_nthreads -= ws._nthreads + self.state._total_nthreads -= ws._nthreads if not dh_addresses: - del parent._host_info[host] + del self.state._host_info[host] self.rpc.remove(address) del self.stream_comms[address] - del parent._aliases[ws._name] - parent._idle.pop(ws._address, None) - parent._saturated.discard(ws) - del parent._workers[address] + del self.state._aliases[ws._name] + self.state._idle.pop(ws._address, None) + self.state._saturated.discard(ws) + del self.state._workers[address] ws.status = Status.closed - parent._running.discard(ws) - parent._total_occupancy -= ws._occupancy + self.state._running.discard(ws) + self.state._total_occupancy -= ws._occupancy recommendations: dict = {} @@ -5084,7 +5373,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): ) for ts in list(ws._has_what): - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if not ts._who_has: if ts._run_spec: recommendations[ts._key] = "released" @@ -5101,16 +5390,16 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): except Exception as e: logger.exception(e) - if not parent._workers_dv: + if not self.state._workers_dv: logger.info("Lost all workers") - for w in parent._workers_dv: + for w in self.state._workers_dv: self.bandwidth_workers.pop((address, w), None) self.bandwidth_workers.pop((w, address), None) def remove_worker_from_events(): # If the worker isn't registered anymore after the delay, remove from events - if address not in parent._workers_dv and address in self.events: + if address not in self.state._workers_dv and address in self.events: del 
self.events[address] cleanup_delay = parse_timedelta( @@ -5134,11 +5423,11 @@ def stimulus_cancel(self, comm, keys=None, client=None, force=False): def cancel_key(self, key, client, retries=5, force=False): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks.get(key) + + ts: TaskState = self.state._tasks.get(key) dts: TaskState try: - cs: ClientState = parent._clients[client] + cs: ClientState = self.state._clients[client] except KeyError: return if ts is None or not ts._who_wants: # no key yet, lets try again in a moment @@ -5157,17 +5446,17 @@ def cancel_key(self, key, client, retries=5, force=False): self.client_releases_keys(keys=[key], client=cs._client_key) def client_desires_keys(self, keys=None, client=None): - parent: SchedulerState = cast(SchedulerState, self) - cs: ClientState = parent._clients.get(client) + + cs: ClientState = self.state._clients.get(client) if cs is None: # For publish, queues etc. - parent._clients[client] = cs = ClientState(client) + self.state._clients[client] = cs = ClientState(client) ts: TaskState for k in keys: - ts = parent._tasks.get(k) + ts = self.state._tasks.get(k) if ts is None: # For publish, queues etc. - ts = parent.new_task(k, None, "released") + ts = self.state.new_task(k, None, "released") ts._who_wants.add(cs) cs._wants_what.add(ts) @@ -5177,175 +5466,24 @@ def client_desires_keys(self, keys=None, client=None): def client_releases_keys(self, keys=None, client=None): """Remove keys from client desired list""" - parent: SchedulerState = cast(SchedulerState, self) if not isinstance(keys, list): keys = list(keys) - cs: ClientState = parent._clients[client] + cs: ClientState = self.state._clients[client] recommendations: dict = {} - _client_releases_keys(parent, keys=keys, cs=cs, recommendations=recommendations) + _client_releases_keys( + self.state, keys=keys, cs=cs, recommendations=recommendations + ) self.transitions(recommendations) def client_heartbeat(self, client=None): """Handle heartbeats from Client""" - parent: SchedulerState = cast(SchedulerState, self) - cs: ClientState = parent._clients[client] - cs._last_seen = time() - - ################### - # Task Validation # - ################### - - def validate_released(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._state == "released" - assert not ts._waiters - assert not ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert not any([ts in dts._waiters for dts in ts._dependencies]) - assert ts not in parent._unrunnable - - def validate_waiting(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._waiting_on - assert not ts._who_has - assert not ts._processing_on - assert ts not in parent._unrunnable - for dts in ts._dependencies: - # We are waiting on a dependency iff it's not stored - assert bool(dts._who_has) != (dts in ts._waiting_on) - assert ts in dts._waiters # XXX even if dts._who_has? 
- def validate_processing(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert not ts._waiting_on - ws: WorkerState = ts._processing_on - assert ws - assert ts in ws._processing - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - assert ts in dts._waiters - - def validate_memory(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts._who_has - assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) - assert not ts._processing_on - assert not ts._waiting_on - assert ts not in parent._unrunnable - for dts in ts._dependents: - assert (dts in ts._waiters) == (dts._state in ("waiting", "processing")) - assert ts not in dts._waiting_on - - def validate_no_worker(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - dts: TaskState - assert ts in parent._unrunnable - assert not ts._waiting_on - assert ts in parent._unrunnable - assert not ts._processing_on - assert not ts._who_has - for dts in ts._dependencies: - assert dts._who_has - - def validate_erred(self, key): - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState = parent._tasks[key] - assert ts._exception_blame - assert not ts._who_has - - def validate_key(self, key, ts: TaskState = None): - parent: SchedulerState = cast(SchedulerState, self) - try: - if ts is None: - ts = parent._tasks.get(key) - if ts is None: - logger.debug("Key lost: %s", key) - else: - ts.validate() - try: - func = getattr(self, "validate_" + ts._state.replace("-", "_")) - except AttributeError: - logger.error( - "self.validate_%s not found", ts._state.replace("-", "_") - ) - else: - func(key) - except Exception as e: - logger.exception(e) - if LOG_PDB: - import pdb - - pdb.set_trace() - raise - - def validate_state(self, allow_overlap=False): - parent: SchedulerState = cast(SchedulerState, self) - validate_state(parent._tasks, parent._workers, parent._clients) - - if not (set(parent._workers_dv) == set(self.stream_comms)): - raise ValueError("Workers not the same in all collections") - - ws: WorkerState - for w, ws in parent._workers_dv.items(): - assert isinstance(w, str), (type(w), w) - assert isinstance(ws, WorkerState), (type(ws), ws) - assert ws._address == w - if not ws._processing: - assert not ws._occupancy - assert ws._address in parent._idle_dv - assert (ws._status == Status.running) == (ws in parent._running) - - for ws in parent._running: - assert ws._status == Status.running - assert ws._address in parent._workers_dv - - ts: TaskState - for k, ts in parent._tasks.items(): - assert isinstance(ts, TaskState), (type(ts), ts) - assert ts._key == k - assert bool(ts in parent._replicated_tasks) == (len(ts._who_has) > 1) - self.validate_key(k, ts) - - for ts in parent._replicated_tasks: - assert ts._state == "memory" - assert ts._key in parent._tasks - - c: str - cs: ClientState - for c, cs in parent._clients.items(): - # client=None is often used in tests... 
- assert c is None or type(c) == str, (type(c), c) - assert type(cs) == ClientState, (type(cs), cs) - assert cs._client_key == c - - a = {w: ws._nbytes for w, ws in parent._workers_dv.items()} - b = { - w: sum(ts.get_nbytes() for ts in ws._has_what) - for w, ws in parent._workers_dv.items() - } - assert a == b, (a, b) - - actual_total_occupancy = 0 - for worker, ws in parent._workers_dv.items(): - assert abs(sum(ws._processing.values()) - ws._occupancy) < 1e-8 - actual_total_occupancy += ws._occupancy + cs: ClientState = self.state._clients[client] + cs._last_seen = time() - assert abs(actual_total_occupancy - parent._total_occupancy) < 1e-8, ( - actual_total_occupancy, - parent._total_occupancy, - ) + def validate_state(self): + self.state.validate_state() ################### # Manage Messages # @@ -5358,11 +5496,11 @@ def report(self, msg: dict, ts: TaskState = None, client: str = None): If the message contains a key then we only send the message to those comms that care about the key. """ - parent: SchedulerState = cast(SchedulerState, self) + if ts is None: msg_key = msg.get("key") if msg_key is not None: - tasks: dict = parent._tasks + tasks: dict = self.state._tasks ts = tasks.get(msg_key) cs: ClientState @@ -5400,12 +5538,12 @@ async def add_client(self, comm, client=None, versions=None): We listen to all future messages from this Comm. """ - parent: SchedulerState = cast(SchedulerState, self) + assert client is not None comm.name = "Scheduler->Client" logger.info("Receive client connection: %s", client) self.log_event(["all", client], {"action": "add-client", "client": client}) - parent._clients[client] = ClientState(client, versions=versions) + self.state._clients[client] = ClientState(client, versions=versions) for plugin in list(self.plugins.values()): try: @@ -5421,7 +5559,7 @@ async def add_client(self, comm, client=None, versions=None): ws: WorkerState version_warning = version_module.error_message( version_module.get_versions(), - {w: ws._versions for w, ws in parent._workers_dv.items()}, + {w: ws._versions for w, ws in self.state._workers_dv.items()}, versions, ) msg.update(version_warning) @@ -5446,12 +5584,12 @@ async def add_client(self, comm, client=None, versions=None): def remove_client(self, client=None): """Remove client from network""" - parent: SchedulerState = cast(SchedulerState, self) + if self.status == Status.running: logger.info("Remove client %s", client) self.log_event(["all", client], {"action": "remove-client", "client": client}) try: - cs: ClientState = parent._clients[client] + cs: ClientState = self.state._clients[client] except KeyError: # XXX is this a legitimate condition? 
pass @@ -5460,7 +5598,7 @@ def remove_client(self, client=None): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) - del parent._clients[client] + del self.state._clients[client] for plugin in list(self.plugins.values()): try: @@ -5470,7 +5608,7 @@ def remove_client(self, client=None): def remove_client_from_events(): # If the client isn't registered anymore after the delay, remove from events - if client not in parent._clients and client in self.events: + if client not in self.state._clients and client in self.events: del self.events[client] cleanup_delay = parse_timedelta( @@ -5480,9 +5618,9 @@ def remove_client_from_events(): def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): """Send a single computational task to a worker""" - parent: SchedulerState = cast(SchedulerState, self) + try: - msg: dict = _task_to_msg(parent, ts, duration) + msg: dict = _task_to_msg(self.state, ts, duration) self.worker_send(worker, msg) except Exception as e: logger.exception(e) @@ -5496,8 +5634,8 @@ def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) def handle_task_finished(self, key=None, worker=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: + + if worker not in self.state._workers_dv: return validate_key(key) @@ -5507,18 +5645,18 @@ def handle_task_finished(self, key=None, worker=None, **msg): r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) def handle_task_erred(self, key=None, **msg): - parent: SchedulerState = cast(SchedulerState, self) + recommendations: dict client_msgs: dict worker_msgs: dict r: tuple = self.stimulus_task_erred(key=key, **msg) recommendations, client_msgs, worker_msgs = r - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) @@ -5539,16 +5677,16 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): errant_worker : str, optional Address of the worker supposed to hold a replica, by default None """ - parent: SchedulerState = cast(SchedulerState, self) + logger.debug("handle missing data key=%s worker=%s", key, errant_worker) self.log_event(errant_worker, {"action": "missing-data", "key": key}) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state._tasks.get(key) if ts is None: return - ws: WorkerState = parent._workers_dv.get(errant_worker) + ws: WorkerState = self.state._workers_dv.get(errant_worker) if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if ts.state == "memory" and not ts._who_has: if ts._run_spec: self.transitions({key: "released"}) @@ -5556,14 +5694,14 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): self.transitions({key: "forgotten"}) def release_worker_data(self, comm=None, key=None, worker=None): - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) - ts: TaskState = parent._tasks.get(key) + + ws: WorkerState = self.state._workers_dv.get(worker) + ts: TaskState = self.state._tasks.get(key) if not ws or not ts: return recommendations: dict = {} if ws in ts._who_has: - 
parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if not ts._who_has: recommendations[ts._key] = "released" if recommendations: @@ -5575,12 +5713,12 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): We stop the task from being stolen in the future, and change task duration accounting as if the task has stopped. """ - parent: SchedulerState = cast(SchedulerState, self) - if key not in parent._tasks: + + if key not in self.state._tasks: logger.debug("Skipping long_running since key %s was already released", key) return - ts: TaskState = parent._tasks[key] - steal = parent._extensions.get("stealing") + ts: TaskState = self.state._tasks[key] + steal = self.state._extensions.get("stealing") if steal is not None: steal.remove_key_from_stealable(ts) @@ -5602,18 +5740,18 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): occ: double = ws._processing[ts] ws._occupancy -= occ - parent._total_occupancy -= occ + self.state._total_occupancy -= occ # Cannot remove from processing since we're using this for things like # idleness detection. Idle workers are typically targeted for # downscaling but we should not downscale workers with long running # tasks ws._processing[ts] = 0 ws._long_running.add(ts) - self.check_idle_saturated(ws) + self.state.check_idle_saturated(ws) def handle_worker_status_change(self, status: str, worker: str) -> None: - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) # type: ignore + + ws: WorkerState = self.state._workers_dv.get(worker) # type: ignore if not ws: return prev_status = ws._status @@ -5631,22 +5769,22 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: ) if ws._status == Status.running: - parent._running.add(ws) + self.state._running.add(ws) recs = {} ts: TaskState - for ts in parent._unrunnable: + for ts in self.state._unrunnable: valid: set = self.valid_workers(ts) if valid is None or ws in valid: recs[ts._key] = "waiting" if recs: client_msgs: dict = {} worker_msgs: dict = {} - parent._transitions(recs, client_msgs, worker_msgs) + self.state._transitions(recs, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) else: - parent._running.discard(ws) + self.state._running.discard(ws) async def handle_worker(self, comm=None, worker=None): """ @@ -5873,16 +6011,16 @@ async def scatter( -------- Scheduler.broadcast: """ - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState start = time() while True: if workers is None: - wss = parent._running + wss = self.state._running else: workers = [self.coerce_address(w) for w in workers] - wss = {parent._workers_dv[w] for w in workers} + wss = {self.state._workers_dv[w] for w in workers} wss = {ws for ws in wss if ws._status == Status.running} if wss: @@ -5911,13 +6049,13 @@ async def scatter( return keys async def gather(self, comm=None, keys=None, serializers=None): - """Collect data from workers to the scheduler""" - parent: SchedulerState = cast(SchedulerState, self) + """Collect data in from workers""" + ws: WorkerState keys = list(keys) who_has = {} for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state._tasks.get(key) if ts is not None: who_has[key] = [ws._address for ws in ts._who_has] else: @@ -5930,7 +6068,7 @@ async def gather(self, comm=None, keys=None, serializers=None): result = {"status": "OK", "data": data} else: missing_states = [ - (parent._tasks[key].state if key in parent._tasks else None) + 
(self.state._tasks[key].state if key in self.state._tasks else None) for key in missing_keys ] logger.exception( @@ -5955,7 +6093,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for key, workers in missing_keys.items(): # Task may already be gone if it was held by a # `missing_worker` - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state._tasks.get(key) logger.exception( "Workers don't have promised key: %s, %s", str(workers), @@ -5965,10 +6103,10 @@ async def gather(self, comm=None, keys=None, serializers=None): continue recommendations: dict = {key: "released"} for worker in workers: - ws = parent._workers_dv.get(worker) + ws = self.state._workers_dv.get(worker) if ws is not None and ws in ts._who_has: - parent.remove_replica(ts, ws) - parent._transitions( + self.state.remove_replica(ts, ws) + self.state._transitions( recommendations, client_msgs, worker_msgs ) self.send_all(client_msgs, worker_msgs) @@ -5983,25 +6121,25 @@ def clear_task_state(self): for collection in self._task_state_collections: collection.clear() - async def restart(self, client=None, timeout=30): - """Restart all workers. Reset local state.""" - parent: SchedulerState = cast(SchedulerState, self) + async def restart(self, client=None, timeout=3): + """Restart all workers. Reset local state.""" + with log_errors(): - n_workers = len(parent._workers_dv) + n_workers = len(self.state._workers_dv) logger.info("Send lost future signal to clients") cs: ClientState ts: TaskState - for cs in parent._clients.values(): + for cs in self.state._clients.values(): self.client_releases_keys( keys=[ts._key for ts in cs._wants_what], client=cs._client_key ) ws: WorkerState - nannies = {addr: ws._nanny for addr, ws in parent._workers_dv.items()} + nannies = {addr: ws._nanny for addr, ws in self.state._workers_dv.items()} - for addr in list(parent._workers_dv): + for addr in list(self.state._workers_dv): try: # Ask the worker to close if it doesn't have a nanny, # otherwise the nanny will kill it anyway @@ -6058,7 +6196,7 @@ async def restart(self, client=None, timeout=30): self.log_event([client, "all"], {"action": "restart", "client": client}) start = time() - while time() < start + 10 and len(parent._workers_dv) < n_workers: + while time() < start + 10 and len(self.state._workers_dv) < n_workers: await asyncio.sleep(0.01) self.report({"op": "restart"}) @@ -6075,7 +6213,7 @@ async def broadcast( on_error: "Literal['raise', 'return', 'return_pickle', 'ignore']" = "raise", ) -> dict: # dict[str, Any] """Broadcast message to workers, return all results""" - parent: SchedulerState = cast(SchedulerState, self) + if workers is True: warnings.warn( "workers=True is deprecated; pass workers=None or omit instead", @@ -6084,18 +6222,18 @@ async def broadcast( workers = None if workers is None: if hosts is None: - workers = list(parent._workers_dv) + workers = list(self.state._workers_dv) else: workers = [] if hosts is not None: for host in hosts: - dh: dict = parent._host_info.get(host) # type: ignore + dh: dict = self.state._host_info.get(host) # type: ignore if dh is not None: workers.extend(dh["addresses"]) # TODO replace with worker_list if nanny: - addresses = [parent._workers_dv[w].nanny for w in workers] + addresses = [self.state._workers_dv[w].nanny for w in workers] else: addresses = workers @@ -6171,8 +6309,7 @@ async def gather_on_worker( ) return set(who_has) - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore + ws: 
WorkerState = self.state._workers_dv.get(worker_address) # type: ignore if ws is None: logger.warning(f"Worker {worker_address} lost during replication") @@ -6190,12 +6327,12 @@ async def gather_on_worker( raise ValueError(f"Unexpected message from {worker_address}: {result}") for key in keys_ok: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state._tasks.get(key) # type: ignore if ts is None or ts._state != "memory": logger.warning(f"Key lost during replication: {key}") continue if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) return keys_failed @@ -6208,10 +6345,9 @@ async def delete_worker_data( ---------- worker_address: str Worker address to delete keys from keys: list[str] List of keys to delete on the specified worker """ - parent: SchedulerState = cast(SchedulerState, self) try: await retry_operation( @@ -6228,15 +6364,15 @@ async def delete_worker_data( ) return - ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore + ws: WorkerState = self.state._workers_dv.get(worker_address) # type: ignore if ws is None: return for key in keys: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state._tasks.get(key) # type: ignore if ts is not None and ws in ts._who_has: assert ts._state == "memory" - parent.remove_replica(ts, ws) + self.state.remove_replica(ts, ws) if not ts._who_has: # Last copy deleted self.transitions({key: "released"}) @@ -6314,14 +6450,13 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the allowed workers. """ - parent: SchedulerState = cast(SchedulerState, self) with log_errors(): wss: "Collection[WorkerState]" if workers is not None: - wss = [parent._workers_dv[w] for w in workers] + wss = [self.state._workers_dv[w] for w in workers] else: - wss = parent._workers_dv.values() + wss = self.state._workers_dv.values() if not wss: return {"status": "OK"} @@ -6333,21 +6468,17 @@ async def rebalance( missing_data = [ k for k in keys - if k not in parent._tasks or not parent._tasks[k].who_has + if k not in self.state._tasks or not self.state._tasks[k].who_has ] if missing_data: return {"status": "partial-fail", "keys": missing_data} msgs = self._rebalance_find_msgs(keys, wss) if not msgs: return {"status": "OK"} async with self._lock: result = await self._rebalance_move_data(msgs) if result["status"] == "partial-fail" and keys is None: # Only return failed keys if the client explicitly asked for them result = {"status": "OK"} return result def _rebalance_find_msgs( self, @@ -6380,7 +6511,7 @@ def _rebalance_find_msgs( - recipient worker - task to be transferred """ - parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState ws: WorkerState @@ -6413,15 +6544,18 @@ def _rebalance_find_msgs( # (distributed.worker.memory.recent-to-old-time). # This lets us ignore temporary spikes caused by task heap usage.
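(Aside: the sender/recipient classification performed by the hunk below is easier to see in isolation. The sketch is illustrative only; the fractions and byte counts are hypothetical stand-ins for the MEMORY_REBALANCE_HALF_GAP / MEMORY_REBALANCE_SENDER_MIN / MEMORY_REBALANCE_RECIPIENT_MAX attributes referenced in the code, and WorkerState objects are reduced to (memory, limit) pairs.)

```python
# Illustrative sketch only -- constants below are hypothetical stand-ins for
# the MEMORY_REBALANCE_* attributes, not the configured values.
HALF_GAP = 0.05        # half of the sender/recipient gap, as a fraction of the limit
SENDER_MIN = 0.30      # minimum fill level for a worker to become a sender
RECIPIENT_MAX = 0.60   # maximum fill level for a worker to become a recipient


def classify(workers):
    """workers maps address -> (measured_memory, memory_limit)."""
    mean_memory = sum(m for m, _ in workers.values()) // len(workers)
    senders, recipients = [], []
    for addr, (memory, limit) in workers.items():
        half_gap = int(HALF_GAP * limit)
        if memory >= mean_memory + half_gap and memory >= SENDER_MIN * limit:
            senders.append(addr)
        elif memory <= mean_memory - half_gap and memory <= RECIPIENT_MAX * limit:
            recipients.append(addr)
    return senders, recipients


# One busy worker, one idle worker, one sitting inside the dead band:
print(classify({
    "tcp://a:1234": (8_000_000_000, 10_000_000_000),
    "tcp://b:1234": (1_000_000_000, 10_000_000_000),
    "tcp://c:1234": (4_500_000_000, 10_000_000_000),
}))
# (['tcp://a:1234'], ['tcp://b:1234'])
```

Workers well above the cluster mean become senders, workers well below it become recipients, and anything inside the dead band is left untouched.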
memory_by_worker = [ - (ws, getattr(ws.memory, parent.MEMORY_REBALANCE_MEASURE)) for ws in workers + (ws, getattr(ws.memory, self.state.MEMORY_REBALANCE_MEASURE)) + for ws in workers ] mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) for ws, ws_memory in memory_by_worker: if ws.memory_limit: - half_gap = int(parent.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) - sender_min = parent.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit - recipient_max = parent.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit + half_gap = int(self.state.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) + sender_min = self.state.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit + recipient_max = ( + self.state.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit + ) else: half_gap = 0 sender_min = 0.0 @@ -6504,7 +6638,7 @@ def _rebalance_find_msgs( # move on to the next task of the same sender. continue # Schedule task for transfer from sender to recipient msgs.append((snd_ws, rec_ws, ts)) # *_bytes_max/min are all negative for heap sorting @@ -6525,7 +6659,7 @@ def _rebalance_find_msgs( else: heapq.heappop(senders) # If recipient still has bytes to gain, push it back into the recipients # heap; it may or may not come back on top again. if rec_bytes_min < 0: # See definition of recipients above @@ -6550,17 +6684,13 @@ async def _rebalance_move_data( self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" ) -> dict: """Perform the actual transfer of data across the network in rebalance(). Takes in input the output of _rebalance_find_msgs(), that is a list of tuples: - sender worker - recipient worker - task to be transferred FIXME this method is not robust when the cluster is not idle.
""" + ts: TaskState snd_ws: WorkerState rec_ws: WorkerState - ts: TaskState to_recipients: "defaultdict[str, defaultdict[str, list[str]]]" = defaultdict( lambda: defaultdict(list) @@ -6581,17 +6711,17 @@ async def _rebalance_move_data( ) to_senders = defaultdict(list) for snd_ws, rec_ws, ts in msgs: if ts._key not in failed_keys_by_recipient[rec_ws.address]: to_senders[snd_ws.address].append(ts._key) # Note: this never raises exceptions await asyncio.gather( *(self.delete_worker_data(r, v) for r, v in to_senders.items()) ) for r, v in to_recipients.items(): self.log_event(r, {"action": "rebalance", "who_has": v}) self.log_event( "all", { @@ -6602,11 +6732,31 @@ async def _rebalance_move_data( }, ) missing_keys = {k for r in failed_keys_by_recipient.values() for k in r} if missing_keys: return {"status": "partial-fail", "keys": list(missing_keys)} else: return {"status": "OK"} async def replicate( self, @@ -6639,7 +6789,7 @@ async def replicate( -------- Scheduler.rebalance """ - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState wws: WorkerState ts: TaskState @@ -6647,10 +6797,12 @@ async def replicate( assert branching_factor > 0 async with self._lock if lock else empty_context: if workers is not None: - workers = {parent._workers_dv[w] for w in self.workers_list(workers)} + workers = { + self.state._workers_dv[w] for w in self.workers_list(workers) + } workers = {ws for ws in workers if ws._status == Status.running} else: - workers = parent._running + workers = self.state._running if n is None: n = len(workers) @@ -6659,10 +6811,10 @@ async def replicate( if n == 0: raise ValueError("Can not use replicate to delete data") - tasks = {parent._tasks[k] for k in keys} + tasks = {self.state._tasks[k] for k in keys} missing_data = [ts._key for ts in tasks if not ts._who_has] if missing_data: return {"status": "partial-fail", "keys": missing_data} # Delete extraneous data if delete: @@ -6675,7 +6827,6 @@ async def replicate( ): del_worker_tasks[ws].add(ts) # Note: this never raises exceptions await asyncio.gather( *[ self.delete_worker_data(ws._address, [t.key for t in tasks]) @@ -6705,15 +6856,20 @@ async def replicate( wws._address for wws in ts._who_has ] await asyncio.gather( *( # Note: this never raises exceptions self.gather_on_worker(w, who_has) for w, who_has in gathers.items() ) ) for r, v in gathers.items(): self.log_event(r, {"action": "replicate-add", "who_has": v}) self.log_event( "all",
@@ -6792,20 +6948,22 @@ def workers_to_close( -------- Scheduler.retire_workers """ - parent: SchedulerState = cast(SchedulerState, self) + if target is not None and n is None: - n = len(parent._workers_dv) - target + n = len(self.state._workers_dv) - target if n is not None: if n < 0: n = 0 - target = len(parent._workers_dv) - n + target = len(self.state._workers_dv) - n if n is None and memory_ratio is None: memory_ratio = 2 ws: WorkerState with log_errors(): - if not n and all([ws._processing for ws in parent._workers_dv.values()]): + if not n and all( + [ws._processing for ws in self.state._workers_dv.values()] + ): return [] if key is None: @@ -6815,7 +6973,7 @@ def workers_to_close( ): key = pickle.loads(key) - groups = groupby(key, parent._workers.values()) + groups = groupby(key, self.state._workers.values()) limit_bytes = { k: sum([ws._memory_limit for ws in v]) for k, v in groups.items() @@ -6834,7 +6992,7 @@ def _key(group): idle = sorted(groups, key=_key) to_close = [] - n_remain = len(parent._workers_dv) + n_remain = len(self.state._workers_dv) while idle: group = idle.pop() @@ -6902,7 +7060,7 @@ async def retire_workers( -------- Scheduler.workers_to_close """ - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState ts: TaskState with log_errors(): @@ -6922,18 +7080,18 @@ async def retire_workers( names_set = {str(name) for name in names} wss = { ws - for ws in parent._workers_dv.values() + for ws in self.state._workers_dv.values() if str(ws._name) in names_set } elif workers is not None: wss = { - parent._workers_dv[address] + self.state._workers_dv[address] for address in workers - if address in parent._workers_dv + if address in self.state._workers_dv } else: wss = { - parent._workers_dv[address] + self.state._workers_dv[address] for address in self.workers_to_close(**kwargs) } if not wss: @@ -6998,7 +7156,6 @@ async def _track_retire_worker( close_workers: bool, remove: bool, ) -> tuple: # tuple[str | None, dict] - parent: SchedulerState = cast(SchedulerState, self) while not policy.done(): if policy.no_recipients: @@ -7019,7 +7176,7 @@ async def _track_retire_worker( "All unique keys on worker %s have been replicated elsewhere", ws._address ) - if close_workers and ws._address in parent._workers_dv: + if close_workers and ws._address in self.state._workers_dv: await self.close_worker(worker=ws._address, safe=True) if remove: await self.remove_worker(address=ws._address, safe=True) @@ -7034,16 +7191,16 @@ def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None): This should not be used in practice and is mostly here for legacy reasons. However, it is sent by workers from time to time. 
""" - parent: SchedulerState = cast(SchedulerState, self) - if worker not in parent._workers_dv: + + if worker not in self.state._workers_dv: return "not found" - ws: WorkerState = parent._workers_dv[worker] + ws: WorkerState = self.state._workers_dv[worker] redundant_replicas = [] for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = self.state._tasks.get(key) if ts is not None and ts._state == "memory": if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) else: redundant_replicas.append(key) @@ -7077,7 +7234,7 @@ def update_data( -------- Scheduler.mark_key_in_memory """ - parent: SchedulerState = cast(SchedulerState, self) + with log_errors(): who_has = { k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() @@ -7085,18 +7242,18 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = self.state._tasks.get(key) # type: ignore if ts is None: - ts = parent.new_task(key, None, "memory") + ts = self.state.new_task(key, None, "memory") ts.state = "memory" ts_nbytes = nbytes.get(key, -1) if ts_nbytes >= 0: ts.set_nbytes(ts_nbytes) for w in workers: - ws: WorkerState = parent._workers_dv[w] + ws: WorkerState = self.state._workers_dv[w] if ws not in ts._who_has: - parent.add_replica(ts, ws) + self.state.add_replica(ts, ws) self.report( {"op": "key-in-memory", "key": key, "workers": list(workers)} ) @@ -7105,9 +7262,9 @@ def update_data( self.client_desires_keys(keys=list(who_has), client=client) def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): - parent: SchedulerState = cast(SchedulerState, self) + if ts is None: - ts = parent._tasks.get(key) + ts = self.state._tasks.get(key) elif key is None: key = ts._key else: @@ -7118,7 +7275,7 @@ def report_on_key(self, key: str = None, ts: TaskState = None, client: str = Non if ts is None: report_msg = {"op": "cancelled-key", "key": key} else: - report_msg = _task_to_report_msg(parent, ts) + report_msg = _task_to_report_msg(self.state, ts) if report_msg is not None: self.report(report_msg, ts=ts, client=client) @@ -7177,79 +7334,80 @@ def subscribe_worker_status(self, comm=None): return ident def get_processing(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState ts: TaskState if workers is not None: workers = set(map(self.coerce_address, workers)) return { - w: [ts._key for ts in parent._workers_dv[w].processing] for w in workers + w: [ts._key for ts in self.state._workers_dv[w].processing] + for w in workers } else: return { w: [ts._key for ts in ws._processing] - for w, ws in parent._workers_dv.items() + for w, ws in self.state._workers_dv.items() } def get_who_has(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState ts: TaskState if keys is not None: return { - k: [ws._address for ws in parent._tasks[k].who_has] - if k in parent._tasks + k: [ws._address for ws in self.state._tasks[k].who_has] + if k in self.state._tasks else [] for k in keys } else: return { key: [ws._address for ws in ts._who_has] - for key, ts in parent._tasks.items() + for key, ts in self.state._tasks.items() } def get_has_what(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState ts: TaskState if workers is not None: workers = map(self.coerce_address, workers) return { - w: [ts._key for ts in 
parent._workers_dv[w].has_what] - if w in parent._workers_dv + w: [ts._key for ts in self.state._workers_dv[w].has_what] + if w in self.state._workers_dv else [] for w in workers } else: return { w: [ts._key for ts in ws.has_what] - for w, ws in parent._workers_dv.items() + for w, ws in self.state._workers_dv.items() } def get_ncores(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState if workers is not None: workers = map(self.coerce_address, workers) return { - w: parent._workers_dv[w].nthreads + w: self.state._workers_dv[w].nthreads for w in workers - if w in parent._workers_dv + if w in self.state._workers_dv } else: - return {w: ws._nthreads for w, ws in parent._workers_dv.items()} + return {w: ws._nthreads for w, ws in self.state._workers_dv.items()} def get_ncores_running(self, comm=None, workers=None): - parent: SchedulerState = cast(SchedulerState, self) + ncores = self.get_ncores(workers=workers) return { w: n for w, n in ncores.items() - if parent._workers_dv[w].status == Status.running + if self.state._workers_dv[w].status == Status.running } async def get_call_stack(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState dts: TaskState if keys is not None: @@ -7257,7 +7415,7 @@ async def get_call_stack(self, comm=None, keys=None): processing = set() while stack: key = stack.pop() - ts = parent._tasks[key] + ts = self.state._tasks[key] if ts._state == "waiting": stack.extend([dts._key for dts in ts._dependencies]) elif ts._state == "processing": @@ -7268,7 +7426,7 @@ async def get_call_stack(self, comm=None, keys=None): if ts._processing_on: workers[ts._processing_on.address].append(ts._key) else: - workers = {w: None for w in parent._workers_dv} + workers = {w: None for w in self.state._workers_dv} if not workers: return {} @@ -7280,14 +7438,16 @@ async def get_call_stack(self, comm=None, keys=None): return response def get_nbytes(self, comm=None, keys=None, summary=True): - parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState with log_errors(): if keys is not None: - result = {k: parent._tasks[k].nbytes for k in keys} + result = {k: self.state._tasks[k].nbytes for k in keys} else: result = { - k: ts._nbytes for k, ts in parent._tasks.items() if ts._nbytes >= 0 + k: ts._nbytes + for k, ts in self.state._tasks.items() + if ts._nbytes >= 0 } if summary: @@ -7318,8 +7478,8 @@ def run_function(self, stream, function, args=(), kwargs={}, wait=True): return run(self, stream, function=function, args=args, kwargs=kwargs, wait=wait) def set_metadata(self, comm=None, keys=None, value=None): - parent: SchedulerState = cast(SchedulerState, self) - metadata = parent._task_metadata + + metadata = self.state._task_metadata for key in keys[:-1]: if key not in metadata or not isinstance(metadata[key], (dict, list)): metadata[key] = {} @@ -7327,8 +7487,8 @@ def set_metadata(self, comm=None, keys=None, value=None): metadata[keys[-1]] = value def get_metadata(self, comm=None, keys=None, default=no_default): - parent: SchedulerState = cast(SchedulerState, self) - metadata = parent._task_metadata + + metadata = self.state._task_metadata for key in keys[:-1]: metadata = metadata[key] try: @@ -7368,9 +7528,9 @@ def get_task_prefix_states(self, comm=None): return state def get_task_status(self, comm=None, keys=None): - parent: SchedulerState = cast(SchedulerState, self) + return { - key: (parent._tasks[key].state if key in parent._tasks else None) + key: (self.state._tasks[key].state if key in 
self.state._tasks else None) for key in keys } @@ -7461,11 +7621,11 @@ def transition(self, key, finish: str, *args, **kwargs): -------- Scheduler.transitions: transitive version of this function """ - parent: SchedulerState = cast(SchedulerState, self) + recommendations: dict worker_msgs: dict client_msgs: dict - a: tuple = parent._transition(key, finish, *args, **kwargs) + a: tuple = self.state._transition(key, finish, args=args, kwargs=kwargs) recommendations, client_msgs, worker_msgs = a self.send_all(client_msgs, worker_msgs) return recommendations @@ -7476,17 +7636,19 @@ def transitions(self, recommendations: dict): This includes feedback from previous transitions and continues until we reach a steady state """ - parent: SchedulerState = cast(SchedulerState, self) + client_msgs: dict = {} worker_msgs: dict = {} - parent._transitions(recommendations, client_msgs, worker_msgs) + self.state._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) def story(self, *keys): """Get all transitions that touch one of the input keys""" keys = {key.key if isinstance(key, TaskState) else key for key in keys} return [ - t for t in self.transition_log if t[0] in keys or keys.intersection(t[3]) + t + for t in self.state.transition_log + if t[0] in keys or keys.intersection(t[3]) ] transition_story = story @@ -7497,10 +7659,10 @@ def reschedule(self, key=None, worker=None): Things may have shifted and this task may now be better suited to run elsewhere """ - parent: SchedulerState = cast(SchedulerState, self) + ts: TaskState try: - ts = parent._tasks[key] + ts = self.state._tasks[key] except KeyError: logger.warning( "Attempting to reschedule task {}, which was not " @@ -7518,26 +7680,26 @@ def reschedule(self, key=None, worker=None): ##################### def add_resources(self, comm=None, worker=None, resources=None): - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + + ws: WorkerState = self.state._workers_dv[worker] if resources: ws._resources.update(resources) ws._used_resources = {} for resource, quantity in ws._resources.items(): ws._used_resources[resource] = 0 - dr: dict = parent._resources.get(resource, None) + dr: dict = self.state._resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = {} + self.state._resources[resource] = dr = {} dr[worker] = quantity return "OK" def remove_resources(self, worker): - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[worker] + + ws: WorkerState = self.state._workers_dv[worker] for resource, quantity in ws._resources.items(): - dr: dict = parent._resources.get(resource, None) + dr: dict = self.state._resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = {} + self.state._resources[resource] = dr = {} del dr[worker] def coerce_address(self, addr, resolve=True): @@ -7548,9 +7710,9 @@ def coerce_address(self, addr, resolve=True): Handles strings, tuples, or aliases. """ # XXX how many address-parsing routines do we have? - parent: SchedulerState = cast(SchedulerState, self) - if addr in parent._aliases: - addr = parent._aliases[addr] + + if addr in self.state._aliases: + addr = self.state._aliases[addr] if isinstance(addr, tuple): addr = unparse_host_port(*addr) if not isinstance(addr, str): @@ -7570,16 +7732,18 @@ def workers_list(self, workers): Takes a list of worker addresses or hostnames. 
Returns a list of all worker addresses that match """ - parent: SchedulerState = cast(SchedulerState, self) + if workers is None: - return list(parent._workers) + return list(self.state._workers) out = set() for w in workers: if ":" in w: out.add(w) else: - out.update({ww for ww in parent._workers if w in ww}) # TODO: quadratic + out.update( + {ww for ww in self.state._workers if w in ww} + ) # TODO: quadratic return list(out) def start_ipython(self, comm=None): @@ -7606,11 +7770,11 @@ async def get_profile( stop=None, key=None, ): - parent: SchedulerState = cast(SchedulerState, self) + if workers is None: - workers = parent._workers_dv + workers = self.state._workers_dv else: - workers = set(parent._workers_dv) & set(workers) + workers = set(self.state._workers_dv) & set(workers) if scheduler: return profile.get_profile(self.io_loop.profile, start=start, stop=stop) @@ -7640,16 +7804,16 @@ async def get_profile_metadata( stop=None, profile_cycle_interval=None, ): - parent: SchedulerState = cast(SchedulerState, self) + dt = profile_cycle_interval or dask.config.get( "distributed.worker.profile.cycle" ) dt = parse_timedelta(dt, default="ms") if workers is None: - workers = parent._workers_dv + workers = self.state._workers_dv else: - workers = set(parent._workers_dv) & set(workers) + workers = set(self.state._workers_dv) & set(workers) results = await asyncio.gather( *(self.rpc(w).profile_metadata(start=start, stop=stop) for w in workers), return_exceptions=True, @@ -7682,10 +7846,8 @@ async def get_profile_metadata( return {"counts": counts, "keys": keys} async def performance_report( self, comm=None, start=None, last_count=None, code="", mode=None ): - parent: SchedulerState = cast(SchedulerState, self) + stop = time() # Profiles compute, scheduler, workers = await asyncio.gather( @@ -7783,10 +7945,10 @@ def profile_to_figure(state): ntasks=total_tasks, tasks_timings=tasks_timings, address=self.address, - nworkers=len(parent._workers_dv), - threads=sum([ws._nthreads for ws in parent._workers_dv.values()]), + nworkers=len(self.state._workers_dv), + threads=sum([ws._nthreads for ws in self.state._workers_dv.values()]), memory=format_bytes( - sum([ws._memory_limit for ws in parent._workers_dv.values()]) + sum([ws._memory_limit for ws in self.state._workers_dv.values()]) ), code=code, dask_version=dask.__version__, @@ -7836,7 +7998,7 @@ def profile_to_figure(state): from bokeh.plotting import output_file, save with tmpfile(extension=".html") as fn: output_file(filename=fn, title="Dask Performance Report", mode=mode) template_directory = os.path.join( os.path.dirname(os.path.abspath(__file__)), "dashboard", "templates" ) @@ -7892,16 +8054,16 @@ def get_events(self, comm=None, topic=None): return valmap(tuple, self.events) async def get_worker_monitor_info(self, recent=False, starts=None): - parent: SchedulerState = cast(SchedulerState, self) + if starts is None: starts = {} results = await asyncio.gather( *( self.rpc(w).get_monitor_info(recent=recent, start=starts.get(w, 0)) - for w in parent._workers_dv + for w in self.state._workers_dv ) ) - return dict(zip(parent._workers_dv, results)) + return dict(zip(self.state._workers_dv, results)) ########### # Cleanup # @@ -7922,7 +8084,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): lets us avoid this fringe optimization when we have better things to think about.
""" - parent: SchedulerState = cast(SchedulerState, self) + try: if self.status == Status.closed: return @@ -7930,7 +8092,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): next_time = timedelta(seconds=0.1) if self.proc.cpu_percent() < 50: - workers: list = list(parent._workers.values()) + workers: list = list(self.state._workers.values()) nworkers: Py_ssize_t = len(workers) i: Py_ssize_t for i in range(nworkers): @@ -7939,7 +8101,7 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): try: if ws is None or not ws._processing: continue - parent._reevaluate_occupancy_worker(ws) + self.state._reevaluate_occupancy_worker(ws) finally: del ws # lose ref @@ -7957,12 +8119,13 @@ def reevaluate_occupancy(self, worker_index: Py_ssize_t = 0): raise async def check_worker_ttl(self): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState now = time() - for ws in parent._workers_dv.values(): + for ws in self.state._workers_dv.values(): if (ws._last_seen < now - self.worker_ttl) and ( - ws._last_seen < now - 10 * heartbeat_interval(len(parent._workers_dv)) + ws._last_seen + < now - 10 * heartbeat_interval(len(self.state._workers_dv)) ): logger.warning( "Worker failed to heartbeat within %s seconds. Closing: %s", @@ -7972,11 +8135,11 @@ async def check_worker_ttl(self): await self.remove_worker(address=ws._address) def check_idle(self): - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState if ( - any([ws._processing for ws in parent._workers_dv.values()]) - or parent._unrunnable + any([ws._processing for ws in self.state._workers_dv.values()]) + or self.state._unrunnable ): self.idle_since = None return @@ -8006,20 +8169,20 @@ def adaptive_target(self, comm=None, target_duration=None): -------- distributed.deploy.Adaptive """ - parent: SchedulerState = cast(SchedulerState, self) + if target_duration is None: target_duration = dask.config.get("distributed.adaptive.target-duration") target_duration = parse_timedelta(target_duration) # CPU cpu = math.ceil( - parent._total_occupancy / target_duration + self.state._total_occupancy / target_duration ) # TODO: threads per worker # Avoid a few long tasks from asking for many cores ws: WorkerState tasks_processing = 0 - for ws in parent._workers_dv.values(): + for ws in self.state._workers_dv.values(): tasks_processing += len(ws._processing) if tasks_processing > cpu: @@ -8027,35 +8190,35 @@ def adaptive_target(self, comm=None, target_duration=None): else: cpu = min(tasks_processing, cpu) - if parent._unrunnable and not parent._workers_dv: + if self.state._unrunnable and not self.state._workers_dv: cpu = max(1, cpu) # add more workers if more than 60% of memory is used - limit = sum([ws._memory_limit for ws in parent._workers_dv.values()]) - used = sum([ws._nbytes for ws in parent._workers_dv.values()]) + limit = sum([ws._memory_limit for ws in self.state._workers_dv.values()]) + used = sum([ws._nbytes for ws in self.state._workers_dv.values()]) memory = 0 if used > 0.6 * limit and limit > 0: - memory = 2 * len(parent._workers_dv) + memory = 2 * len(self.state._workers_dv) target = max(memory, cpu) - if target >= len(parent._workers_dv): + if target >= len(self.state._workers_dv): return target else: # Scale down? 
to_close = self.workers_to_close() - return len(parent._workers_dv) - len(to_close) + return len(self.state._workers_dv) - len(to_close) def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str): """Asynchronously ask a worker to acquire a replica of the listed keys from other workers. This is a fire-and-forget operation which offers no feedback for success or failure, and is intended for housekeeping and not for computation. """ - parent: SchedulerState = cast(SchedulerState, self) + ws: WorkerState ts: TaskState who_has = {} for key in keys: - ts = parent._tasks[key] + ts = self.state._tasks[key] who_has[key] = {ws._address for ws in ts._who_has} self.stream_comms[addr].send( @@ -8078,15 +8241,15 @@ def request_remove_replicas(self, addr: str, keys: list, *, stimulus_id: str): to re-add itself to who_has. If the worker agrees to discard the task, there is no feedback. """ - parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv[addr] + + ws: WorkerState = self.state._workers_dv[addr] validate = self.validate # The scheduler immediately forgets about the replica and suggests the worker to # drop it. The worker may refuse, at which point it will send back an add-keys # message to reinstate it. for key in keys: - ts: TaskState = parent._tasks[key] + ts: TaskState = self.state._tasks[key] if validate: # Do not destroy the last copy assert len(ts._who_has) > 1 @@ -8389,6 +8552,7 @@ def decide_worker( return ws +@ccall def validate_task_state(ts: TaskState): """ Validate the given TaskState. @@ -8445,7 +8609,7 @@ def validate_task_state(ts: TaskState): assert dts._state != "forgotten" assert (ts._processing_on is not None) == (ts._state == "processing") - assert bool(ts._who_has) == (ts._state == "memory"), (ts, ts._who_has, ts._state) + assert (not not ts._who_has) == (ts._state == "memory"), (ts, ts._who_has) if ts._state == "processing": assert all([dts._who_has for dts in ts._dependencies]), ( @@ -8490,6 +8654,7 @@ def validate_task_state(ts: TaskState): assert ts in ts._processing_on.actors +@ccall def validate_worker_state(ws: WorkerState): ts: TaskState for ts in ws._has_what: @@ -8504,6 +8669,7 @@ def validate_worker_state(ws: WorkerState): assert ts._state in ("memory", "processing") +@ccall def validate_state(tasks, workers, clients): """ Validate a current runtime state diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 4612d9179d..31d1d95a21 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3872,7 +3872,7 @@ async def test_idempotence(s, a, b): # Submit x = c.submit(inc, 1) await x - log = list(s.transition_log) + log = list(s.state.transition_log) len_single_submit = len(log) # see last assert @@ -3880,29 +3880,29 @@ async def test_idempotence(s, a, b): assert x.key == y.key await y await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 # Error a = c.submit(div, 1, 0) await wait(a) assert a.status == "error" - log = list(s.transition_log) + log = list(s.state.transition_log) b = f.submit(div, 1, 0) assert a.key == b.key await wait(b) await asyncio.sleep(0.1) - log2 = list(s.transition_log) + log2 = list(s.state.transition_log) assert log == log2 - s.transition_log.clear() + s.state.transition_log.clear() # Simultaneous Submit d = c.submit(inc, 2) e = c.submit(inc, 2) await wait([d, e]) - assert len(s.transition_log) == len_single_submit + assert len(s.state.transition_log) == len_single_submit 
await c.close() await f.close() @@ -4566,7 +4566,7 @@ def test_auto_normalize_collection_sync(c): def assert_no_data_loss(scheduler): - for key, start, finish, recommendations, _ in scheduler.transition_log: + for key, start, finish, recommendations, _ in scheduler.state.transition_log: if start == "memory" and finish == "released": for k, v in recommendations.items(): assert not (k == key and v == "waiting") @@ -5557,7 +5557,7 @@ def fib(x): future = c.submit(fib, 8) result = await future assert result == 21 - assert len(s.transition_log) > 50 + assert len(s.state.transition_log) > 50 @gen_cluster(client=True) @@ -6873,7 +6873,7 @@ def test_computation_object_code_dask_compute(client): test_function_code = inspect.getsource(test_computation_object_code_dask_compute) def fetch_comp_code(dask_scheduler): - computations = list(dask_scheduler.computations) + computations = list(dask_scheduler.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6913,7 +6913,7 @@ async def test_computation_object_code_dask_persist(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_dask_persist.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -6933,7 +6933,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_simple.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6954,7 +6954,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_list_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6976,7 +6976,7 @@ def func(x): test_function_code = inspect.getsource( test_computation_object_code_client_submit_dict_comp.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] @@ -6996,7 +6996,7 @@ async def test_computation_object_code_client_map(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_map.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 @@ -7014,7 +7014,7 @@ async def test_computation_object_code_client_compute(c, s, a, b): test_function_code = inspect.getsource( test_computation_object_code_client_compute.__wrapped__ ) - computations = list(s.computations) + computations = list(s.state.computations) assert len(computations) == 1 comp = computations[0] assert len(comp.code) == 1 diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6d5081ef0b..3a4e416f22 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -734,11 +734,11 @@ async def test_coerce_address(s): assert s.coerce_address(123) == b.address assert s.coerce_address("charlie") == c.address - assert s.coerce_hostname("127.0.0.1") == "127.0.0.1" - assert s.coerce_hostname("alice") == a.ip - assert s.coerce_hostname(123) == b.ip - assert s.coerce_hostname("charlie") == c.ip - assert s.coerce_hostname("jimmy") == "jimmy" + assert s.state.coerce_hostname("127.0.0.1") == "127.0.0.1" + assert s.state.coerce_hostname("alice") 
== a.ip + assert s.state.coerce_hostname(123) == b.ip + assert s.state.coerce_hostname("charlie") == c.ip + assert s.state.coerce_hostname("jimmy") == "jimmy" assert s.coerce_address("zzzt:8000", resolve=False) == "tcp://zzzt:8000" await asyncio.gather(a.close(), b.close(), c.close()) @@ -796,11 +796,11 @@ async def test_story(c, s, a, b): f = c.persist(y) await wait([f]) - assert s.transition_log + assert s.state.transition_log story = s.story(x.key) - assert all(line in s.transition_log for line in story) - assert len(story) < len(s.transition_log) + assert all(line in s.state.transition_log for line in story) + assert len(story) < len(s.state.transition_log) assert all(x.key == line[0] or x.key in line[-2] for line in story) assert len(s.story(x.key, y.key)) > len(story) @@ -1561,13 +1561,13 @@ async def test_dont_recompute_if_persisted(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1598,12 +1598,12 @@ async def test_dont_recompute_if_persisted_3(c, s, a, b): ww = w.persist() await wait(ww) - old = list(s.transition_log) + old = list(s.state.transition_log) www = w.persist() await wait(www) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster(client=True) @@ -1650,13 +1650,13 @@ async def test_dont_recompute_if_erred(c, s, a, b): yy = y.persist() await wait(yy) - old = list(s.transition_log) + old = list(s.state.transition_log) yyy = y.persist() await wait(yyy) await asyncio.sleep(0.100) - assert list(s.transition_log) == old + assert list(s.state.transition_log) == old @gen_cluster() @@ -3172,16 +3172,16 @@ async def test_computations(c, s, a, b): z = (x - 2).persist() await z - assert len(s.computations) == 2 - assert "add" in str(s.computations[0].groups) - assert "sub" in str(s.computations[1].groups) - assert "sub" not in str(s.computations[0].groups) + assert len(s.state.computations) == 2 + assert "add" in str(s.state.computations[0].groups) + assert "sub" in str(s.state.computations[1].groups) + assert "sub" not in str(s.state.computations[0].groups) - assert isinstance(repr(s.computations[1]), str) + assert isinstance(repr(s.state.computations[1]), str) - assert s.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) + assert s.state.computations[1].stop == max(tg.stop for tg in s.task_groups.values()) - assert s.computations[0].states["memory"] == y.npartitions + assert s.state.computations[0].states["memory"] == y.npartitions @gen_cluster(client=True) @@ -3190,7 +3190,7 @@ async def test_computations_futures(c, s, a, b): total = c.submit(sum, futures) await total - [computation] = s.computations + [computation] = s.state.computations assert "sum" in str(computation.groups) assert "inc" in str(computation.groups)
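(Aside: the test changes above only assume that the scheduler state keeps an append-only, bounded transition log and that story() filters it by key, either directly or through a transition's recommendations. A stripped-down model with plain tuples instead of TaskState objects:)

```python
# Stripped-down model of what the tests rely on; real entries hold task keys
# and recommendation dicts produced by SchedulerState._transition().
from collections import deque
from time import time

transition_log = deque(maxlen=100_000)  # bounded, like the scheduler's log

def log_transition(key, start, finish, recommendations):
    transition_log.append((key, start, finish, recommendations, time()))

def story(*keys):
    keys = set(keys)
    return [t for t in transition_log if t[0] in keys or keys & set(t[3])]

log_transition("x", "released", "waiting", {})
log_transition("x", "waiting", "processing", {})
log_transition("x", "processing", "memory", {"y": "waiting"})
log_transition("z", "released", "waiting", {})

assert len(story("x")) == 3
assert len(story("y")) == 1   # "y" appears only via another key's recommendations
assert len(story("z")) == 1
```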