Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support Arrow zerocopy serialization in object store #35110

Merged
merged 6 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions java/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ define_java_module(
"@maven//:commons_io_commons_io",
"@maven//:de_ruedigermoeller_fst",
"@maven//:net_java_dev_jna_jna",
"@maven//:org_apache_arrow_arrow_memory_core",
Deegue marked this conversation as resolved.
Show resolved Hide resolved
"@maven//:org_apache_arrow_arrow_memory_unsafe",
"@maven//:org_apache_arrow_arrow_vector",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_logging_log4j_log4j_api",
"@maven//:org_apache_logging_log4j_log4j_core",
Expand All @@ -117,6 +120,7 @@ define_java_module(
"@maven//:com_sun_xml_bind_jaxb_impl",
"@maven//:commons_io_commons_io",
"@maven//:javax_xml_bind_jaxb_api",
"@maven//:org_apache_arrow_arrow_vector",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_apache_logging_log4j_log4j_api",
"@maven//:org_apache_logging_log4j_log4j_core",
Expand Down
3 changes: 3 additions & 0 deletions java/dependencies.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def gen_java_deps():
"org.slf4j:slf4j-api:1.7.25",
"com.lmax:disruptor:3.3.4",
"net.java.dev.jna:jna:5.8.0",
"org.apache.arrow:arrow-memory-core:5.0.0",
"org.apache.arrow:arrow-memory-unsafe:5.0.0",
"org.apache.arrow:arrow-vector:5.0.0",
"org.apache.httpcomponents.client5:httpclient5:5.0.3",
"org.apache.httpcomponents.core5:httpcore5:5.0.2",
"org.apache.httpcomponents.client5:httpclient5-fluent:5.0.3",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import io.ray.runtime.generated.Common.ErrorType;
import io.ray.runtime.serializer.RayExceptionSerializer;
import io.ray.runtime.serializer.Serializer;
import io.ray.runtime.util.ArrowUtil;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.commons.lang3.tuple.Pair;

/**
Expand Down Expand Up @@ -49,6 +51,7 @@ public class ObjectSerializer {
private static final byte[] TASK_EXECUTION_EXCEPTION_META =
String.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();

public static final byte[] OBJECT_METADATA_TYPE_ARROW = "ARROW".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
Expand Down Expand Up @@ -80,7 +83,9 @@ public static Object deserialize(

if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_ARROW) == 0) {
return ArrowUtil.deserialize(data);
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
if (objectType == ByteBuffer.class) {
return ByteBuffer.wrap(data);
}
Expand Down Expand Up @@ -136,6 +141,10 @@ public static NativeRayObject serialize(Object object) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof VectorSchemaRoot) {
// serialize arrow data using IPC Stream format
byte[] bytes = ArrowUtil.serialize((VectorSchemaRoot) object);
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_ARROW);
} else if (object instanceof ByteBuffer) {
// Serialize ByteBuffer to raw bytes.
ByteBuffer buffer = (ByteBuffer) object;
Expand Down
61 changes: 61 additions & 0 deletions java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.ray.runtime.util;

import io.ray.api.exception.RayException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.channels.Channels;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
import org.apache.arrow.vector.ipc.message.MessageResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;

/** Helper method for serialize and deserialize arrow data. */
public class ArrowUtil {
public static final RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);

/**
* Deserialize data in byte array to arrow data format.
*
* @return The vector schema root of arrow.
*/
public static VectorSchemaRoot deserialize(byte[] data) {
try (MessageChannelReader reader =
new MessageChannelReader(
new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data))), rootAllocator)) {
MessageResult result = reader.readNext();
Schema schema = MessageSerializer.deserializeSchema(result.getMessage());
VectorSchemaRoot root = VectorSchemaRoot.create(schema, rootAllocator);
VectorLoader loader = new VectorLoader(root);
result = reader.readNext();
ArrowRecordBatch batch =
MessageSerializer.deserializeRecordBatch(result.getMessage(), result.getBodyBuffer());
loader.load(batch);
return root;
} catch (Exception e) {
throw new RayException("Failed to deserialize Arrow data", e.getCause());
}
}

/**
* Serialize data from arrow data format to byte array.
*
* @return The byte array of data.
*/
public static byte[] serialize(VectorSchemaRoot root) {
try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
ArrowStreamWriter writer = new ArrowStreamWriter(root, null, sink)) {
writer.start();
writer.writeBatch();
writer.end();
return sink.toByteArray();
} catch (Exception e) {
throw new RayException("Failed to serialize Arrow data", e.getCause());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package io.ray.test;

import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.function.PyFunction;
import io.ray.runtime.util.ArrowUtil;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

@Test(groups = {"cluster"})
public class CrossLanguageObjectStoreTest extends BaseTest {

private static final String PYTHON_MODULE = "test_cross_language_invocation";
private static final int vecSize = 5;

@BeforeClass
public void beforeClass() {
// Delete and re-create the temp dir.
File tempDir =
new File(
System.getProperty("java.io.tmpdir")
+ File.separator
+ "ray_cross_language_object_store_test");
FileUtils.deleteQuietly(tempDir);
tempDir.mkdirs();
tempDir.deleteOnExit();

// Write the test Python file to the temp dir.
InputStream in =
CrossLanguageObjectStoreTest.class.getResourceAsStream(
File.separator + PYTHON_MODULE + ".py");
File pythonFile = new File(tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
try {
FileUtils.copyInputStreamToFile(in, pythonFile);
} catch (IOException e) {
throw new RuntimeException(e);
}

System.setProperty(
"ray.job.code-search-path",
System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath());
}

@Test
public void testPythonPutAndJavaGet() {
ObjectRef<VectorSchemaRoot> res =
Ray.task(PyFunction.of(PYTHON_MODULE, "py_put_into_object_store", VectorSchemaRoot.class))
.remote();
VectorSchemaRoot root = res.get();
BigIntVector newVector = (BigIntVector) root.getVector(0);
for (int i = 0; i < vecSize; i++) {
Assert.assertEquals(i, newVector.get(i));
}
}

@Test
public void testJavaPutAndPythonGet() {
BigIntVector vector = new BigIntVector("ArrowBigIntVector", ArrowUtil.rootAllocator);
vector.setValueCount(vecSize);
for (int i = 0; i < vecSize; i++) {
vector.setSafe(i, i);
}
List<Field> fields = Arrays.asList(vector.getField());
List<FieldVector> vectors = Arrays.asList(vector);
VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors);
ObjectRef<VectorSchemaRoot> obj = Ray.put(root);

ObjectRef<VectorSchemaRoot> res =
Ray.task(
PyFunction.of(
PYTHON_MODULE, "py_object_store_get_and_check", VectorSchemaRoot.class),
obj)
.remote();

VectorSchemaRoot newRoot = res.get();
BigIntVector newVector = (BigIntVector) newRoot.getVector(0);
for (int i = 0; i < vecSize; i++) {
Assert.assertEquals(i, newVector.get(i));
}
}
}
25 changes: 25 additions & 0 deletions java/test/src/main/java/io/ray/test/ObjectStoreTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import io.ray.api.Ray;
import io.ray.api.exception.RayTaskException;
import io.ray.api.exception.UnreconstructableException;
import io.ray.runtime.util.ArrowUtil;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -75,6 +81,25 @@ public void testGetMultipleObjects() {
Assert.assertEquals(ints, Ray.get(refs));
}

@Test
public void testArrowObjects() {
final int vecSize = 10;
IntVector vector = new IntVector("ArrowIntVector", ArrowUtil.rootAllocator);
vector.setValueCount(vecSize);
for (int i = 0; i < vecSize; i++) {
vector.setSafe(i, i);
}
List<Field> fields = Arrays.asList(vector.getField());
List<FieldVector> vectors = Arrays.asList(vector);
VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors);
ObjectRef<VectorSchemaRoot> obj = Ray.put(root);
VectorSchemaRoot newRoot = obj.get();
IntVector newVector = (IntVector) newRoot.getVector(0);
for (int i = 0; i < vecSize; i++) {
Assert.assertEquals(i, newVector.get(i));
}
}

@Test(groups = {"cluster"})
public void testOwnerAssignWhenPut() throws Exception {
// This test should align with test_owner_assign_when_put in Python
Expand Down
34 changes: 34 additions & 0 deletions java/test/src/main/resources/test_cross_language_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import asyncio

import pyarrow as pa

import ray


Expand Down Expand Up @@ -189,3 +191,35 @@ def py_func_call_java_overloaded_method():
result = ray.get([ref1, ref2])
assert result == ["first", "firstsecond"]
return True


@ray.remote
def py_put_into_object_store():
column_values = [0, 1, 2, 3, 4]
column_array = pa.array(column_values)
table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"])
return table


@ray.remote
def py_object_store_get_and_check(table):
column_values = [0, 1, 2, 3, 4]
column_array = pa.array(column_values)
expected_table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"])

for column_name in table.column_names:
column1 = table[column_name]
column2 = expected_table[column_name]

indices = pa.compute.equal(column1, column2).to_pylist()
differing_rows = [i for i, index in enumerate(indices) if not index]

if differing_rows:
print(f"Differences in column '{column_name}':")
for row in differing_rows:
value1 = column1[row].as_py()
value2 = column2[row].as_py()
print(f"Row {row}: {value1} != {value2}")
raise RuntimeError("Check failed, two tables are not equal!")

return table
2 changes: 2 additions & 0 deletions python/ray/_private/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def env_set_by_user(key):
OBJECT_METADATA_TYPE_PYTHON = b"PYTHON"
# A constant used as object metadata to indicate the object is raw bytes.
OBJECT_METADATA_TYPE_RAW = b"RAW"
# A constant used as object metadata to indicate the object is arrow data.
OBJECT_METADATA_TYPE_ARROW = b"ARROW"

# A constant used as object metadata to indicate the object is an actor handle.
# This value should be synchronized with the Java definition in
Expand Down
23 changes: 21 additions & 2 deletions python/ray/_private/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ray.cloudpickle as pickle
from ray._private import ray_constants
from ray._raylet import (
ArrowSerializedObject,
MessagePackSerializedObject,
MessagePackSerializer,
ObjectRefGenerator,
Expand Down Expand Up @@ -47,6 +48,11 @@
from ray.util import serialization_addons
from ray.util import inspect_serializability

try:
import pyarrow as pa
except ImportError:
pa = None

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -270,6 +276,12 @@ def _deserialize_object(self, data, metadata, object_ref):
if data is None:
return b""
return data.to_pybytes()
elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW:
assert (
pa is not None
), "pyarrow should be imported while deserializing arrow objects"
reader = pa.BufferReader(data)
return pa.ipc.open_stream(reader).read_all()
elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
obj = self._deserialize_msgpack_data(data, metadata_fields)
return _actor_handle_deserializer(obj)
Expand Down Expand Up @@ -461,5 +473,12 @@ def serialize(self, value):
# use a special metadata to indicate it's raw binary. So
# that this object can also be read by Java.
return RawSerializedObject(value)
else:
return self._serialize_to_msgpack(value)

# Check whether arrow is installed. If so, use Arrow IPC format
# to serialize this object, then it can also be read by Java.
if pa is not None and (
isinstance(value, pa.Table) or isinstance(value, pa.RecordBatch)
):
return ArrowSerializedObject(value)

return self._serialize_to_msgpack(value)
4 changes: 3 additions & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table":
cols = table.columns
new_cols = []
for col in cols:
if _is_column_extension_type(col):
if col.num_chunks == 0:
arr = pyarrow.chunked_array([], type=col.type)
elif _is_column_extension_type(col):
# Extension arrays don't support concatenation.
arr = _concatenate_extension_column(col)
else:
Expand Down
Loading