diff --git a/distributed/client.py b/distributed/client.py index bbda46acf4..5ce872c064 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2581,7 +2581,7 @@ def _graph_to_futures( if not isinstance(dsk, HighLevelGraph): dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=()) - dsk = highlevelgraph_pack(dsk, keyset, self, self.futures) + dsk = highlevelgraph_pack(dsk, self, keyset) if isinstance(retries, Number) and retries > 0: retries = {k: retries for k in dsk} diff --git a/distributed/protocol/highlevelgraph.py b/distributed/protocol/highlevelgraph.py index 77de669ebb..512a7abd4b 100644 --- a/distributed/protocol/highlevelgraph.py +++ b/distributed/protocol/highlevelgraph.py @@ -27,9 +27,8 @@ def _materialized_layer_pack( layer: Layer, all_keys, known_key_dependencies, + client, client_keys, - allowed_client, - allowed_futures, ): from ..client import Future @@ -47,11 +46,11 @@ def _materialized_layer_pack( dsk = {k: unpack_remotedata(v, byte_keys=True) for k, v in layer.items()} unpacked_futures = set.union(*[v[1] for v in dsk.values()]) if dsk else set() for future in unpacked_futures: - if future.client is not allowed_client: + if future.client is not client: raise ValueError( "Inputs contain futures that were created by another client." ) - if tokey(future.key) not in allowed_futures: + if tokey(future.key) not in client.futures: raise CancelledError(tokey(future.key)) unpacked_futures_deps = {} for k, v in dsk.items(): @@ -76,15 +75,13 @@ def _materialized_layer_pack( return {"dsk": dsk, "dependencies": dependencies} -def highlevelgraph_pack( - hlg: HighLevelGraph, client_keys, allowed_client, allowed_futures -): +def highlevelgraph_pack(hlg: HighLevelGraph, client, client_keys): layers = [] # Dump each layer (in topological order) for layer in (hlg.layers[name] for name in hlg._toposort_layers()): if not layer.is_materialized(): - state = layer.__dask_distributed_pack__() + state = layer.__dask_distributed_pack__(client) if state is not None: layers.append( { @@ -104,9 +101,8 @@ def highlevelgraph_pack( layer, hlg.get_all_external_keys(), hlg.key_dependencies, + client, client_keys, - allowed_client, - allowed_futures, ), } )