Skip to content

Commit

Permalink
[Core] Refactoring Ray DAG object scanner (ray-project#26917)
Browse files Browse the repository at this point in the history
* make sure Ray DAG can work with minimal install

Signed-off-by: Siyuan Zhuang <[email protected]>
Signed-off-by: Rohan138 <[email protected]>
  • Loading branch information
suquark authored and Rohan138 committed Jul 28, 2022
1 parent c6e0ed2 commit 7f7f4a7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 52 deletions.
8 changes: 8 additions & 0 deletions python/ray/dag/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""This module defines the base class for object scanning and gets rid of
reference cycles."""
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
class DAGNodeBase:
    """Shared base type for every node in a Ray task graph.

    Concrete node classes (function, class, input nodes, ...) inherit from
    this marker class; isolating it in its own module lets the object
    scanner detect DAG nodes without importing them, avoiding reference
    cycles.
    """
3 changes: 2 additions & 1 deletion python/ray/dag/dag_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.util.annotations import DeveloperAPI

Expand All @@ -18,7 +19,7 @@


@DeveloperAPI
class DAGNode:
class DAGNode(DAGNodeBase):
"""Abstract class for a node in a Ray task graph.
A node has a type (e.g., FunctionNode), data (e.g., function options and
Expand Down
80 changes: 29 additions & 51 deletions python/ray/dag/py_obj_scanner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ray

import uuid
import io
import sys

Expand All @@ -13,13 +12,23 @@
else:
import pickle # noqa: F401

from typing import List, Dict, Any, TypeVar, TYPE_CHECKING

if TYPE_CHECKING:
from ray.dag.dag_node import DAGNode
from typing import List, Dict, Any, TypeVar
from ray.dag.base import DAGNodeBase

T = TypeVar("T")

# Used in deserialization hooks to reference scanner instances.
_instances: Dict[int, "_PyObjScanner"] = {}


def _get_node(instance_id: int, node_index: int) -> DAGNodeBase:
    """Resolve ``node_index`` against the scanner registered under ``instance_id``.

    Note: This function should be static and globally importable,
    otherwise the serialization overhead would be very significant.
    """
    scanner = _instances[instance_id]
    return scanner._replace_index(node_index)


class _PyObjScanner(ray.cloudpickle.CloudPickler):
"""Utility to find and replace DAGNodes in Python objects.
Expand All @@ -29,50 +38,26 @@ class _PyObjScanner(ray.cloudpickle.CloudPickler):
table and then replace the nodes via ``replace_nodes()``.
"""

# Used in deserialization hooks to reference scanner instances.
_instances: Dict[str, "_PyObjScanner"] = {}

def __init__(self):
    # NOTE(review): this span is a rendered diff with removed and added
    # lines interleaved — the uuid/class-registry and dispatch_table code
    # is the PRE-refactor implementation; the ``Dict[DAGNodeBase, T]`` /
    # ``_instances[id(self)]`` lines near the end are the POST-refactor
    # replacement. A real module would contain only one of the two.
    # Buffer to keep intermediate serialized state.
    self._buf = io.BytesIO()
    # List of top-level DAGNodes found during the serialization pass.
    self._found = None
    # Replacement table to consult during deserialization.
    self._replace_table: Dict["DAGNode", T] = None
    # UUID of this scanner.
    # (old version) scanners were keyed by a uuid in a class-level registry
    self._uuid = uuid.uuid4().hex
    _PyObjScanner._instances[self._uuid] = self
    # Register pickler override for DAGNode types.
    # (old version) every concrete node type had to be imported and added
    # to the pickle dispatch_table explicitly — this forced minimal
    # installs to import ray.serve; the refactor replaces it with a single
    # isinstance check in ``reducer_override``.
    from ray.dag.function_node import FunctionNode
    from ray.dag.class_node import ClassNode, ClassMethodNode
    from ray.dag.input_node import InputNode, InputAttributeNode
    from ray.serve.deployment_node import DeploymentNode
    from ray.serve.deployment_method_node import DeploymentMethodNode
    from ray.serve.deployment_function_node import DeploymentFunctionNode
    from ray.serve.deployment_executor_node import DeploymentExecutorNode
    from ray.serve.deployment_method_executor_node import (
        DeploymentMethodExecutorNode,
    )
    from ray.serve.deployment_function_executor_node import (
        DeploymentFunctionExecutorNode,
    )

    self.dispatch_table[FunctionNode] = self._reduce_dag_node
    self.dispatch_table[ClassNode] = self._reduce_dag_node
    self.dispatch_table[ClassMethodNode] = self._reduce_dag_node
    self.dispatch_table[InputNode] = self._reduce_dag_node
    self.dispatch_table[InputAttributeNode] = self._reduce_dag_node
    self.dispatch_table[DeploymentNode] = self._reduce_dag_node
    self.dispatch_table[DeploymentMethodNode] = self._reduce_dag_node
    self.dispatch_table[DeploymentFunctionNode] = self._reduce_dag_node

    self.dispatch_table[DeploymentExecutorNode] = self._reduce_dag_node
    self.dispatch_table[DeploymentMethodExecutorNode] = self._reduce_dag_node
    self.dispatch_table[DeploymentFunctionExecutorNode] = self._reduce_dag_node

    # (new version) replacement table keyed on the common base class,
    # and the scanner registered in the module-level dict by id().
    self._replace_table: Dict[DAGNodeBase, T] = None
    _instances[id(self)] = self
    super().__init__(self._buf)

def find_nodes(self, obj: Any) -> List["DAGNode"]:
def reducer_override(self, obj):
    """Pickler hook: record DAG nodes instead of serializing them.

    Any ``DAGNodeBase`` instance is appended to ``self._found`` and
    replaced in the stream by a ``_get_node`` call that resolves it at
    load time; everything else takes the normal cloudpickle path.
    """
    if not isinstance(obj, DAGNodeBase):
        return super().reducer_override(obj)
    position = len(self._found)
    self._found.append(obj)
    return _get_node, (id(self), position)

def find_nodes(self, obj: Any) -> List[DAGNodeBase]:
"""Find top-level DAGNodes."""
assert (
self._found is None
Expand All @@ -81,22 +66,15 @@ def find_nodes(self, obj: Any) -> List["DAGNode"]:
self.dump(obj)
return self._found

def replace_nodes(self, table: Dict["DAGNode", T]) -> Any:
def replace_nodes(self, table: Dict[DAGNodeBase, T]) -> Any:
    """Replace previously found DAGNodes per the given table.

    Replays the buffer written by the earlier serialization pass;
    during the load, ``_replace_index`` consults ``self._replace_table``
    to substitute each recorded node.
    """
    assert self._found is not None, "find_nodes must be called first"
    buf = self._buf
    buf.seek(0)
    self._replace_table = table
    return pickle.load(buf)

def _replace_index(self, i: int) -> "DAGNode":
def _replace_index(self, i: int) -> DAGNodeBase:
    """Map the ``i``-th node found during scanning to its replacement."""
    node = self._found[i]
    return self._replace_table[node]

def _reduce_dag_node(self, obj):
    # NOTE(review): legacy reduce hook — in the rendered diff this is the
    # REMOVED counterpart of ``reducer_override`` + ``_get_node`` above.
    # Bind the scanner's uuid locally so the lambda closes over a plain
    # string (not self) and can re-find this instance in the class-level
    # registry at deserialization time.
    uuid = self._uuid
    index = len(self._found)
    # Pickle reduce pair (callable, args): unpickling invokes the lambda
    # with the node's index, yielding the entry from _replace_table.
    res = (lambda i: _PyObjScanner._instances[uuid]._replace_index(i)), (index,)
    self._found.append(obj)
    return res

def __del__(self):
    # Unregister this scanner so the registry does not keep it alive.
    # NOTE(review): the two deletes below appear to be the old and new
    # versions of the same line merged by the diff rendering — the
    # uuid-keyed class registry is pre-refactor, the id()-keyed module
    # registry post-refactor; a real module would keep only one.
    del _PyObjScanner._instances[self._uuid]
    del _instances[id(self)]

0 comments on commit 7f7f4a7

Please sign in to comment.