Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "[Core] Support Arrow zerocopy serialization in object… #36153

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
95986a4
Revert "Revert "[Core] Support Arrow zerocopy serialization in object…
Deegue Jun 7, 2023
015b2e5
test
Deegue Jun 13, 2023
d6d3794
test
Deegue Jun 15, 2023
709ad05
test
Deegue Jun 16, 2023
95f92d0
test
Deegue Jun 20, 2023
bd9c9af
test
Deegue Jun 21, 2023
f00591a
revert test log
Deegue Jun 28, 2023
9097771
Merge branch 'ray-project:master' into raydp_ser_revert
Deegue Jun 28, 2023
ad28538
move import inside
Deegue Jul 7, 2023
e761683
Merge branch 'raydp_ser_revert' of https://github.com/Deegue/ray into…
Deegue Jul 7, 2023
c3e5db3
nit
Deegue Jul 7, 2023
03e64d6
lint and fix test
Deegue Jul 7, 2023
dd10dab
Merge branch 'ray-project:master' into raydp_ser_revert
Deegue Jul 7, 2023
ede5fae
lint
Deegue Jul 7, 2023
e57a8de
move outside
Deegue Jul 10, 2023
74c85f0
nit
Deegue Jul 10, 2023
b96edbf
nit
Deegue Jul 10, 2023
119e0e4
Merge branch 'ray-project:master' into raydp_ser_revert
Deegue Jul 11, 2023
3fa10c3
test
Deegue Jul 11, 2023
aa456f6
Merge branch 'raydp_ser_revert' of https://github.com/Deegue/ray into…
Deegue Jul 11, 2023
e390b6a
lint
Deegue Jul 12, 2023
c02b243
move inside
Deegue Jul 20, 2023
e6ae146
nit
Deegue Jul 20, 2023
6aa4ad8
revert import in transform_pyarrow
kira-lin Jul 21, 2023
c47572b
comment added part and see if ci pass
kira-lin Jul 24, 2023
3da0729
see if import breaks ci
kira-lin Jul 25, 2023
5872b76
add back deser part
kira-lin Jul 25, 2023
5def2f9
add back serialize import
kira-lin Jul 25, 2023
371224d
return table if indices is empty
kira-lin Jul 25, 2023
94ed413
add empty check for take_table
kira-lin Jul 26, 2023
7bd21d3
Merge remote-tracking branch 'upstream/master' into raydp_ser_revert
kira-lin Jul 26, 2023
b53fc29
Merge branch 'ray-project:master' into raydp_ser_revert
Deegue Aug 1, 2023
9b8768d
Merge branch 'ray-project:master' into raydp_ser_revert
Deegue Aug 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions java/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ define_java_module(
"@maven//:commons_io_commons_io",
"@maven//:de_ruedigermoeller_fst",
"@maven//:net_java_dev_jna_jna",
"@maven//:org_apache_arrow_arrow_memory_core",
"@maven//:org_apache_arrow_arrow_memory_unsafe",
"@maven//:org_apache_arrow_arrow_vector",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_logging_log4j_log4j_api",
"@maven//:org_apache_logging_log4j_log4j_core",
Expand All @@ -117,6 +120,7 @@ define_java_module(
"@maven//:com_sun_xml_bind_jaxb_impl",
"@maven//:commons_io_commons_io",
"@maven//:javax_xml_bind_jaxb_api",
"@maven//:org_apache_arrow_arrow_vector",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_logging_log4j_log4j_api",
"@maven//:org_apache_logging_log4j_log4j_core",
Expand Down
3 changes: 3 additions & 0 deletions java/dependencies.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def gen_java_deps():
"org.slf4j:slf4j-api:1.7.25",
"com.lmax:disruptor:3.3.4",
"net.java.dev.jna:jna:5.8.0",
"org.apache.arrow:arrow-memory-core:5.0.0",
"org.apache.arrow:arrow-memory-unsafe:5.0.0",
"org.apache.arrow:arrow-vector:5.0.0",
"org.apache.httpcomponents.client5:httpclient5:5.0.3",
"org.apache.httpcomponents.core5:httpcore5:5.0.2",
"org.apache.httpcomponents.client5:httpclient5-fluent:5.0.3",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import io.ray.runtime.generated.Common.ErrorType;
import io.ray.runtime.serializer.RayExceptionSerializer;
import io.ray.runtime.serializer.Serializer;
import io.ray.runtime.util.ArrowUtil;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.commons.lang3.tuple.Pair;

/**
Expand Down Expand Up @@ -49,6 +51,7 @@ public class ObjectSerializer {
private static final byte[] TASK_EXECUTION_EXCEPTION_META =
String.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();

public static final byte[] OBJECT_METADATA_TYPE_ARROW = "ARROW".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
Expand Down Expand Up @@ -80,7 +83,9 @@ public static Object deserialize(

if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_ARROW) == 0) {
return ArrowUtil.deserialize(data);
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
if (objectType == ByteBuffer.class) {
return ByteBuffer.wrap(data);
}
Expand Down Expand Up @@ -136,6 +141,10 @@ public static NativeRayObject serialize(Object object) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof VectorSchemaRoot) {
// serialize arrow data using IPC Stream format
byte[] bytes = ArrowUtil.serialize((VectorSchemaRoot) object);
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_ARROW);
} else if (object instanceof ByteBuffer) {
// Serialize ByteBuffer to raw bytes.
ByteBuffer buffer = (ByteBuffer) object;
Expand Down
61 changes: 61 additions & 0 deletions java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.ray.runtime.util;

import io.ray.api.exception.RayException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
import org.apache.arrow.vector.ipc.message.MessageResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;

/**
 * Helper methods for converting Arrow data to and from a byte array in the Arrow IPC stream
 * format.
 */
public class ArrowUtil {
  /** Allocator shared by every vector that {@link #deserialize(byte[])} creates. */
  public static final RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);

  /**
   * Deserialize data in byte array to arrow data format.
   *
   * <p>The schema message and the first record batch message are read from {@code data}; any
   * further messages in the stream are not read. The returned root is allocated from {@link
   * #rootAllocator} and should be closed by the caller once it is no longer needed.
   *
   * @param data The bytes of an Arrow IPC stream.
   * @return The vector schema root of arrow.
   * @throws RayException If the stream is empty, has no record batch, or cannot be decoded.
   */
  public static VectorSchemaRoot deserialize(byte[] data) {
    try (MessageChannelReader reader =
        new MessageChannelReader(
            new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data))), rootAllocator)) {
      MessageResult schemaResult = reader.readNext();
      if (schemaResult == null) {
        throw new IllegalStateException("Arrow stream is empty, expected a schema message");
      }
      Schema schema = MessageSerializer.deserializeSchema(schemaResult.getMessage());

      VectorSchemaRoot root = VectorSchemaRoot.create(schema, rootAllocator);
      boolean loaded = false;
      try {
        MessageResult batchResult = reader.readNext();
        if (batchResult == null) {
          throw new IllegalStateException("Arrow stream contains no record batch");
        }
        // The batch holds its own references to the buffers. Once the vectors have been loaded
        // those references must be released, otherwise the buffers are never freed.
        try (ArrowRecordBatch batch =
            MessageSerializer.deserializeRecordBatch(
                batchResult.getMessage(), batchResult.getBodyBuffer())) {
          new VectorLoader(root).load(batch);
        }
        loaded = true;
        return root;
      } finally {
        // Do not leak the partially initialized root when reading or loading fails.
        if (!loaded) {
          root.close();
        }
      }
    } catch (Exception e) {
      // Pass the exception itself as the cause: e.getCause() is usually null and would drop the
      // real failure and its stack trace.
      throw new RayException("Failed to deserialize Arrow data", e);
    }
  }

  /**
   * Serialize data from arrow data format to byte array.
   *
   * <p>The current content of {@code root} is written as a single record batch. The root is not
   * closed by this method.
   *
   * @param root The vector schema root to serialize.
   * @return The byte array of data.
   * @throws RayException If writing the stream fails.
   */
  public static byte[] serialize(VectorSchemaRoot root) {
    try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
        ArrowStreamWriter writer = new ArrowStreamWriter(root, null, sink)) {
      writer.start();
      writer.writeBatch();
      writer.end();
      return sink.toByteArray();
    } catch (Exception e) {
      // Keep the original exception as the cause instead of its (often null) own cause.
      throw new RayException("Failed to serialize Arrow data", e);
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package io.ray.test;

import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
import io.ray.runtime.util.ArrowUtil;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Tests that Arrow data written to the object store by one language can be read back by the
 * other (Python to Java, and Java to Python to Java).
 */
@Test(groups = {"cluster"})
public class CrossLanguageObjectStoreTest extends BaseTest {

  private static final String PYTHON_MODULE = "test_cross_language_invocation";
  private static final int vecSize = 5;

  /**
   * Copies the Python test module from the classpath into a fresh temp directory and adds that
   * directory to the job's code search path.
   */
  @BeforeClass
  public void beforeClass() {
    // Delete and re-create the temp dir.
    File tempDir =
        new File(
            System.getProperty("java.io.tmpdir")
                + File.separator
                + "ray_cross_language_object_store_test");
    FileUtils.deleteQuietly(tempDir);
    tempDir.mkdirs();
    tempDir.deleteOnExit();

    // Write the test Python file to the temp dir.
    // Classpath resource names always use '/', regardless of the platform's file separator.
    String resourceName = "/" + PYTHON_MODULE + ".py";
    InputStream in = CrossLanguageObjectStoreTest.class.getResourceAsStream(resourceName);
    if (in == null) {
      throw new IllegalStateException("Test resource not found on classpath: " + resourceName);
    }
    File pythonFile = new File(tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
    try {
      FileUtils.copyInputStreamToFile(in, pythonFile);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    System.setProperty(
        "ray.job.code-search-path",
        System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath());
  }

  /** A table returned by a Python task is readable in Java as a {@link VectorSchemaRoot}. */
  @Test
  public void testPythonPutAndJavaGet() {
    ObjectRef<VectorSchemaRoot> res =
        Ray.task(PyFunction.of(PYTHON_MODULE, "py_put_into_object_store", VectorSchemaRoot.class))
            .remote();
    // Close the root so the buffers taken from the shared allocator are released.
    try (VectorSchemaRoot root = res.get()) {
      assertSequentialValues((BigIntVector) root.getVector(0));
    }
  }

  /** A root put by Java is readable by a Python task, and the table it returns by Java again. */
  @Test
  public void testJavaPutAndPythonGet() {
    BigIntVector vector = new BigIntVector("ArrowBigIntVector", ArrowUtil.rootAllocator);
    vector.setValueCount(vecSize);
    for (int i = 0; i < vecSize; i++) {
      vector.setSafe(i, i);
    }
    List<Field> fields = Arrays.asList(vector.getField());
    List<FieldVector> vectors = Arrays.asList(vector);
    try (VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors)) {
      ObjectRef<VectorSchemaRoot> obj = Ray.put(root);

      ObjectRef<VectorSchemaRoot> res =
          Ray.task(
                  PyFunction.of(
                      PYTHON_MODULE, "py_object_store_get_and_check", VectorSchemaRoot.class),
                  obj)
              .remote();

      try (VectorSchemaRoot newRoot = res.get()) {
        assertSequentialValues((BigIntVector) newRoot.getVector(0));
      }
    }
  }

  /** Asserts that {@code vector} holds exactly the values 0, 1, ..., vecSize - 1. */
  private static void assertSequentialValues(BigIntVector vector) {
    // TestNG's argument order is (actual, expected).
    Assert.assertEquals(vector.getValueCount(), vecSize);
    for (int i = 0; i < vecSize; i++) {
      Assert.assertEquals(vector.get(i), (long) i);
    }
  }
}
25 changes: 25 additions & 0 deletions java/test/src/main/java/io/ray/test/ObjectStoreTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import io.ray.api.Ray;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.UnreconstructableException;
import io.ray.runtime.util.ArrowUtil;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -75,6 +81,25 @@ public void testGetMultipleObjects() {
Assert.assertEquals(ints, Ray.get(refs));
}

/**
 * Puts an Arrow {@link VectorSchemaRoot} holding the ints 0..9 into the object store and checks
 * that the root read back contains the same values.
 */
@Test
public void testArrowObjects() {
  final int vecSize = 10;
  IntVector vector = new IntVector("ArrowIntVector", ArrowUtil.rootAllocator);
  vector.setValueCount(vecSize);
  for (int i = 0; i < vecSize; i++) {
    vector.setSafe(i, i);
  }
  List<Field> fields = Arrays.asList(vector.getField());
  List<FieldVector> vectors = Arrays.asList(vector);
  // Close both roots so the buffers taken from the shared allocator are released.
  try (VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors)) {
    ObjectRef<VectorSchemaRoot> obj = Ray.put(root);
    try (VectorSchemaRoot newRoot = obj.get()) {
      IntVector newVector = (IntVector) newRoot.getVector(0);
      for (int i = 0; i < vecSize; i++) {
        // TestNG's argument order is (actual, expected).
        Assert.assertEquals(newVector.get(i), i);
      }
    }
  }
}

@Test(groups = {"cluster"})
public void testOwnerAssignWhenPut() throws Exception {
// This test should align with test_owner_assign_when_put in Python
Expand Down
24 changes: 24 additions & 0 deletions java/test/src/main/resources/test_cross_language_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import asyncio

import pyarrow as pa

import ray


Expand Down Expand Up @@ -189,3 +191,25 @@ def py_func_call_java_overloaded_method():
result = ray.get([ref1, ref2])
assert result == ["first", "firstsecond"]
return True


@ray.remote
def py_put_into_object_store():
    """Return a one-column Arrow table named "ArrowBigIntVector" holding 0..4."""
    values = pa.array(list(range(5)))
    return pa.Table.from_arrays([values], names=["ArrowBigIntVector"])


@ray.remote
def py_object_store_get_and_check(table):
    """Check that ``table`` matches the expected Arrow table and return it.

    The expected table has a single column named "ArrowBigIntVector" holding
    the values 0..4.

    Args:
        table: The Arrow table read from the object store.

    Returns:
        The ``table`` argument, unchanged.

    Raises:
        AssertionError: If the column names or any column's contents differ
            from the expected table.
    """
    expected_table = pa.Table.from_arrays(
        [pa.array([0, 1, 2, 3, 4])], names=["ArrowBigIntVector"]
    )

    # Compare the column names first. Iterating only over the received
    # table's columns would pass vacuously if it had no columns at all.
    assert table.column_names == expected_table.column_names, (
        f"unexpected columns: {table.column_names}"
    )

    for column_name in expected_table.column_names:
        assert table[column_name].equals(
            expected_table[column_name]
        ), f"column {column_name!r} differs from the expected values"

    return table
2 changes: 2 additions & 0 deletions python/ray/_private/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def env_set_by_user(key):
OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
# A constant used as object metadata to indicate the object is raw bytes.
OBJECT_METADATA_TYPE_RAW = b"RAW"
# A constant used as object metadata to indicate the object is arrow data.
OBJECT_METADATA_TYPE_ARROW = b"ARROW"

# A constant used as object metadata to indicate the object is an actor handle.
# This value should be synchronized with the Java definition in
Expand Down
25 changes: 23 additions & 2 deletions python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ray.cloudpickle as pickle
from ray._private import ray_constants
from ray._raylet import (
ArrowSerializedObject,
MessagePackSerializedObject,
MessagePackSerializer,
ObjectRefGenerator,
Expand Down Expand Up @@ -271,6 +272,14 @@ def _deserialize_object(self, data, metadata, object_ref):
if data is None:
return b""
return data.to_pybytes()
elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW:
try:
import pyarrow
except ImportError:
pyarrow = None

reader = pyarrow.BufferReader(data)
return pyarrow.ipc.open_stream(reader).read_all()
elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
obj = self._deserialize_msgpack_data(data, metadata_fields)
return _actor_handle_deserializer(obj)
Expand Down Expand Up @@ -464,5 +473,17 @@ def serialize(self, value):
# use a special metadata to indicate it's raw binary. So
# that this object can also be read by Java.
return RawSerializedObject(value)
else:
return self._serialize_to_msgpack(value)

try:
import pyarrow
except ImportError:
pyarrow = None

# Check whether arrow is installed. If so, use Arrow IPC format
# to serialize this object, then it can also be read by Java.
if pyarrow is not None and (
isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch)
):
return ArrowSerializedObject(value)

return self._serialize_to_msgpack(value)
9 changes: 8 additions & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ def take_table(
extension arrays. This is exposed as a static method for easier use on
intermediate tables, not underlying an ArrowBlockAccessor.
"""

from ray.air.util.transform_pyarrow import (
_concatenate_extension_column,
_is_column_extension_type,
)

if table.num_rows == 0:
return table
if len(indices) == 0:
return pyarrow.Table.from_pydict({})
if any(_is_column_extension_type(col) for col in table.columns):
new_cols = []
for col in table.columns:
Expand Down Expand Up @@ -279,7 +284,9 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table":
cols = table.columns
new_cols = []
for col in cols:
if _is_column_extension_type(col):
if col.num_chunks == 0:
arr = pyarrow.chunked_array([], type=col.type)
elif _is_column_extension_type(col):
# Extension arrays don't support concatenation.
arr = _concatenate_extension_column(col)
else:
Expand Down
Loading
Loading