Skip to content

Commit

Permalink
Add type annotations to various functions within distributed.worker (
Browse files Browse the repository at this point in the history
  • Loading branch information
orf authored Sep 14, 2021
1 parent 06835b1 commit 3f86e58
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ dask-worker-space/
.ycm_extra_conf.py
tags
.ipynb_checkpoints
.venv/
36 changes: 23 additions & 13 deletions distributed/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import bisect
import builtins
Expand All @@ -17,7 +19,10 @@
from datetime import timedelta
from inspect import isawaitable
from pickle import PicklingError
from typing import Dict, Hashable, Iterable, Optional
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, Optional

if TYPE_CHECKING:
from .client import Client

from tlz import first, keymap, merge, pluck # noqa: F401
from tornado.ioloop import IOLoop, PeriodicCallback
Expand Down Expand Up @@ -2818,8 +2823,14 @@ async def plugin_remove(self, comm=None, name=None):
return {"status": "OK"}

async def actor_execute(
self, comm=None, actor=None, function=None, args=(), kwargs={}
self,
comm=None,
actor=None,
function=None,
args=(),
kwargs: Optional[dict] = None,
):
kwargs = kwargs or {}
separate_thread = kwargs.pop("separate_thread", True)
key = actor
actor = self.actors[key]
Expand Down Expand Up @@ -2854,7 +2865,7 @@ def actor_attribute(self, comm=None, actor=None, attribute=None):
except Exception as ex:
return {"status": "error", "exception": to_serialize(ex)}

def meets_resource_constraints(self, key):
def meets_resource_constraints(self, key: str) -> bool:
ts = self.tasks[key]
if not ts.resource_restrictions:
return True
Expand Down Expand Up @@ -3264,8 +3275,7 @@ async def get_profile(
return prof

async def get_profile_metadata(self, comm=None, start=0, stop=None):
if stop is None:
add_recent = True
add_recent = stop is None
now = time() + self.scheduler_delay
stop = stop or now
start = start or 0
Expand Down Expand Up @@ -3447,14 +3457,14 @@ def validate_state(self):
#######################################

@property
def client(self):
def client(self) -> Client:
with self._lock:
if self._client:
return self._client
else:
return self._get_client()

def _get_client(self, timeout=None):
def _get_client(self, timeout=None) -> Client:
"""Get local client attached to this worker
If no such client exists, create one
Expand Down Expand Up @@ -3536,7 +3546,7 @@ def get_current_task(self):
return self.active_threads[threading.get_ident()]


def get_worker():
def get_worker() -> Worker:
"""Get the worker currently running this task
Examples
Expand All @@ -3563,7 +3573,7 @@ def get_worker():
raise ValueError("No workers found")


def get_client(address=None, timeout=None, resolve_address=True):
def get_client(address=None, timeout=None, resolve_address=True) -> Client:
"""Get a client while within a task.
This client connects to the same scheduler to which the worker is connected
Expand Down Expand Up @@ -3678,7 +3688,7 @@ class Reschedule(Exception):
"""


def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT):
def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT) -> Optional[int]:
if memory_limit is None:
return None

Expand Down Expand Up @@ -3807,7 +3817,7 @@ def execute_task(task):
_cache_lock = threading.Lock()


def dumps_function(func):
def dumps_function(func) -> bytes:
"""Dump a function to bytes, cache functions"""
try:
with _cache_lock:
Expand Down Expand Up @@ -4028,7 +4038,7 @@ def __repr__(self):
return msg


def convert_args_to_str(args, max_len=None):
def convert_args_to_str(args, max_len: Optional[int] = None) -> str:
"""Convert args to a string, allowing for some arguments to raise
exceptions during conversion and ignoring them.
"""
Expand All @@ -4047,7 +4057,7 @@ def convert_args_to_str(args, max_len=None):
return "({})".format(", ".join(strs))


def convert_kwargs_to_str(kwargs, max_len=None):
def convert_kwargs_to_str(kwargs: dict, max_len: Optional[int] = None) -> str:
"""Convert kwargs to a string, allowing for some arguments to raise
exceptions during conversion and ignoring them.
"""
Expand Down

0 comments on commit 3f86e58

Please sign in to comment.