Skip to content

Commit

Permalink
__dask_distributed_pack__(): client argument (#4248)
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk authored Nov 18, 2020
1 parent 04a6b78 commit eda9bcc
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,7 +2581,7 @@ def _graph_to_futures(
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())

dsk = highlevelgraph_pack(dsk, keyset, self, self.futures)
dsk = highlevelgraph_pack(dsk, self, keyset)

if isinstance(retries, Number) and retries > 0:
retries = {k: retries for k in dsk}
Expand Down
16 changes: 6 additions & 10 deletions distributed/protocol/highlevelgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ def _materialized_layer_pack(
layer: Layer,
all_keys,
known_key_dependencies,
client,
client_keys,
allowed_client,
allowed_futures,
):
from ..client import Future

Expand All @@ -47,11 +46,11 @@ def _materialized_layer_pack(
dsk = {k: unpack_remotedata(v, byte_keys=True) for k, v in layer.items()}
unpacked_futures = set.union(*[v[1] for v in dsk.values()]) if dsk else set()
for future in unpacked_futures:
if future.client is not allowed_client:
if future.client is not client:
raise ValueError(
"Inputs contain futures that were created by another client."
)
if tokey(future.key) not in allowed_futures:
if tokey(future.key) not in client.futures:
raise CancelledError(tokey(future.key))
unpacked_futures_deps = {}
for k, v in dsk.items():
Expand All @@ -76,15 +75,13 @@ def _materialized_layer_pack(
return {"dsk": dsk, "dependencies": dependencies}


def highlevelgraph_pack(
hlg: HighLevelGraph, client_keys, allowed_client, allowed_futures
):
def highlevelgraph_pack(hlg: HighLevelGraph, client, client_keys):
layers = []

# Dump each layer (in topological order)
for layer in (hlg.layers[name] for name in hlg._toposort_layers()):
if not layer.is_materialized():
state = layer.__dask_distributed_pack__()
state = layer.__dask_distributed_pack__(client)
if state is not None:
layers.append(
{
Expand All @@ -104,9 +101,8 @@ def highlevelgraph_pack(
layer,
hlg.get_all_external_keys(),
hlg.key_dependencies,
client,
client_keys,
allowed_client,
allowed_futures,
),
}
)
Expand Down

0 comments on commit eda9bcc

Please sign in to comment.