This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Fixes ray workflow adapter to work with Ray 2.0 #189

Merged: 2 commits, Aug 29, 2022
2 changes: 1 addition & 1 deletion graph_adapter_tests/h_ray/test_h_ray.py
@@ -14,7 +14,7 @@

@pytest.fixture(scope="module")
def init():
-    ray.init(local_mode=True) # need local mode, else it can't seem to find the h_ray module.
+    ray.init()
    yield "initialized"
    ray.shutdown()

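For context, a minimal, self-contained sketch (not part of this change) of how a test can consume the module-scoped init fixture above under Ray 2.0; the square task is purely illustrative:

import pytest
import ray


@pytest.fixture(scope="module")
def init():
    ray.init()
    yield "initialized"
    ray.shutdown()


def test_remote_square(init):
    # Trivial Ray task, just to exercise the shared Ray runtime started by the fixture.
    @ray.remote
    def square(x: int) -> int:
        return x * x

    assert ray.get(square.remote(4)) == 16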
4 changes: 2 additions & 2 deletions hamilton/dev_utils/deprecation.py
@@ -159,8 +159,8 @@ def __call__(self, fn: Callable):
        TODO -- use @singledispatchmethod when we no longer support 3.6/3.7
        https://docs.python.org/3/library/functools.html#functools.singledispatchmethod

-        @param fn: function (or class) to decorate
-        @return: The decorated function.
+        :param fn: function (or class) to decorate
+        :return: The decorated function.
        """
        # In this case we just do a standard decorator
        if isinstance(fn, types.FunctionType):
37 changes: 22 additions & 15 deletions hamilton/experimental/h_ray.py
@@ -1,4 +1,5 @@
import functools
+import inspect
import logging
import typing

@@ -10,6 +11,22 @@
logger = logging.getLogger(__name__)


+def raify(fn):
+    """Makes the function into something ray-friendly.
+    This is necessary due to https://github.com/ray-project/ray/issues/28146.
+
+    :param fn: Function to make ray-friendly
+    :return: The ray-friendly version
+    """
+    if isinstance(fn, functools.partial):
+
+        def new_fn(*args, **kwargs):
+            return fn(*args, **kwargs)
+
+        return new_fn
+    return fn
+
+
class RayGraphAdapter(base.HamiltonGraphAdapter, base.ResultMixin):
    """Class representing what's required to make Hamilton run on Ray

@@ -60,11 +77,7 @@ def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) ->
        :param kwargs: the arguments that should be passed to it.
        :return: returns a ray object reference.
        """
-        if isinstance(node.callable, functools.partial):
-            return functools.partial(
-                ray.remote(node.callable.func).remote, *node.callable.args, **node.callable.keywords
-            )(**kwargs)
-        return ray.remote(node.callable).remote(**kwargs)
+        return ray.remote(raify(node.callable)).remote(**kwargs)

    def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
        """Builds the result and brings it back to this running process.
@@ -139,13 +152,7 @@ def check_node_type_equivalence(node_type: typing.Type, input_type: typing.Type)
        return node_type == input_type

    def execute_node(self, node: node.Node, kwargs: typing.Dict[str, typing.Any]) -> typing.Any:
-        """Function that is called as we walk the graph to determine how to execute a hamilton function.
-
-        :param node: the node from the graph.
-        :param kwargs: the arguments that should be passed to it.
-        :return: returns a ray object reference.
-        """
-        return workflow.step(node.callable).step(**kwargs)
+        return ray.remote(raify(node.callable)).bind(**kwargs)

    def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
        """Builds the result and brings it back to this running process.
@@ -157,8 +164,8 @@ def build_result(self, **outputs: typing.Dict[str, typing.Any]) -> typing.Any:
        for k, v in outputs.items():
            logger.debug(f"Got output {k}, with type [{type(v)}].")
        # need to wrap our result builder in a remote call and then pass in what we want to build from.
-        remote_combine = workflow.step(self.result_builder.build_result).step(**outputs)
-        result = remote_combine.run(
-            workflow_id=self.workflow_id
+        remote_combine = ray.remote(self.result_builder.build_result).bind(**outputs)
+        result = workflow.run(
+            remote_combine, workflow_id=self.workflow_id
        ) # this materializes the object locally
        return result
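The substance of this change is swapping the Ray 1.x workflow API (workflow.step(...).step(...) followed by .run(...)) for Ray 2.0's DAG API, where tasks are composed with .bind() and materialized with workflow.run(). A minimal sketch of that pattern outside Hamilton; the task names and storage path are illustrative assumptions:

import ray
from ray import workflow


@ray.remote
def double(x: int) -> int:
    return x * 2


@ray.remote
def add(a: int, b: int) -> int:
    return a + b


# Ray workflows persist state, so point Ray at a storage location first
# (the path here is an assumption; adjust it for your environment).
ray.init(storage="/tmp/ray_workflow_demo")

# Ray 1.x (what the adapter used before):
#     workflow.step(add).step(workflow.step(double).step(2), 3).run(workflow_id="demo")
# Ray 2.0: compose a DAG with .bind() and materialize it with workflow.run().
dag = add.bind(double.bind(2), 3)
assert workflow.run(dag, workflow_id="demo") == 7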
8 changes: 4 additions & 4 deletions hamilton/function_modifiers.py
@@ -72,8 +72,8 @@ def value(literal_value: Any) -> LiteralDependency:
    """Specifies that a parameterized dependency comes from a "literal" source.
    E.G. value("foo") means that the value is actually the string value "foo"

-    @param literal_value: Python literal value to use
-    @return: A LiteralDependency object -- a signifier to the internal framework of the dependency type
+    :param literal_value: Python literal value to use
+    :return: A LiteralDependency object -- a signifier to the internal framework of the dependency type
    """
    if isinstance(literal_value, LiteralDependency):
        return literal_value
@@ -85,8 +85,8 @@ def source(dependency_on: Any) -> UpstreamDependency:
    This means that it comes from a node somewhere else.
    E.G. source("foo") means that it should be assigned the value that "foo" outputs.

-    @param dependency_on: Upstream node to come from
-    @return:An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
+    :param dependency_on: Upstream node to come from
+    :return:An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
    """
    if isinstance(dependency_on, UpstreamDependency):
        return dependency_on
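For context, value and source are the dependency signifiers consumed by Hamilton's parameterization decorators. A hedged sketch of typical usage with @parameterize follows; the node and column names are illustrative, and the exact decorator form may differ between Hamilton versions:

from hamilton.function_modifiers import parameterize, source, value


@parameterize(
    spend_doubled={"raw_spend": source("spend"), "factor": value(2)},
    spend_tripled={"raw_spend": source("spend"), "factor": value(3)},
)
def scaled_spend(raw_spend: float, factor: float) -> float:
    """Scales the upstream 'spend' node by a literal factor (illustrative example)."""
    return raw_spend * factor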
2 changes: 1 addition & 1 deletion setup.py
@@ -73,7 +73,7 @@ def load_requirements():
        "dask-dataframe": ["dask[dataframe]"],
        "dask-diagnostics": ["dask[diagnostics]"],
        "dask-distributed": ["dask[distributed]"],
-        "ray": ["ray", "pyarrow"],
+        "ray": ["ray>=2.0.0", "pyarrow"],
        "pyspark": ["pyspark[pandas_on_spark]"],
        "pandera": ["pandera"],
    },