From 95986a434a7583fc04d1ebff26df29c9a0023186 Mon Sep 17 00:00:00 2001 From: Deegue Date: Wed, 7 Jun 2023 08:13:14 +0000 Subject: [PATCH 01/25] Revert "Revert "[Core] Support Arrow zerocopy serialization in object store (#35110)" (#36000)" This reverts commit 822904b312423010682fa92f4085c05e08819337. --- java/BUILD.bazel | 4 + java/dependencies.bzl | 3 + .../ray/runtime/object/ObjectSerializer.java | 11 ++- .../java/io/ray/runtime/util/ArrowUtil.java | 61 ++++++++++++ .../test/CrossLanguageObjectStoreTest.java | 92 +++++++++++++++++++ .../java/io/ray/test/ObjectStoreTest.java | 25 +++++ .../test_cross_language_invocation.py | 34 +++++++ python/ray/_private/ray_constants.py | 2 + python/ray/_private/serialization.py | 23 ++++- .../_internal/arrow_ops/transform_pyarrow.py | 4 +- python/ray/includes/serialization.pxi | 34 +++++++ 11 files changed, 289 insertions(+), 4 deletions(-) create mode 100644 java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java create mode 100644 java/test/src/main/java/io/ray/test/CrossLanguageObjectStoreTest.java diff --git a/java/BUILD.bazel b/java/BUILD.bazel index a7f2497695b4..4560f5ec56af 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -91,6 +91,9 @@ define_java_module( "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", "@maven//:net_java_dev_jna_jna", + "@maven//:org_apache_arrow_arrow_memory_core", + "@maven//:org_apache_arrow_arrow_memory_unsafe", + "@maven//:org_apache_arrow_arrow_vector", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_apache_logging_log4j_log4j_api", "@maven//:org_apache_logging_log4j_log4j_core", @@ -117,6 +120,7 @@ define_java_module( "@maven//:com_sun_xml_bind_jaxb_impl", "@maven//:commons_io_commons_io", "@maven//:javax_xml_bind_jaxb_api", + "@maven//:org_apache_arrow_arrow_vector", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_apache_logging_log4j_log4j_api", "@maven//:org_apache_logging_log4j_log4j_core", diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 5a64a0326383..a65ee8d4d87b 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -27,6 +27,9 @@ def gen_java_deps(): "org.slf4j:slf4j-api:1.7.25", "com.lmax:disruptor:3.3.4", "net.java.dev.jna:jna:5.8.0", + "org.apache.arrow:arrow-memory-core:5.0.0", + "org.apache.arrow:arrow-memory-unsafe:5.0.0", + "org.apache.arrow:arrow-vector:5.0.0", "org.apache.httpcomponents.client5:httpclient5:5.0.3", "org.apache.httpcomponents.core5:httpcore5:5.0.2", "org.apache.httpcomponents.client5:httpclient5-fluent:5.0.3", diff --git a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java index a516f415c3b4..7da905e1af5f 100644 --- a/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java +++ b/java/runtime/src/main/java/io/ray/runtime/object/ObjectSerializer.java @@ -14,6 +14,7 @@ import io.ray.runtime.generated.Common.ErrorType; import io.ray.runtime.serializer.RayExceptionSerializer; import io.ray.runtime.serializer.Serializer; +import io.ray.runtime.util.ArrowUtil; import io.ray.runtime.util.IdUtil; import java.nio.ByteBuffer; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.commons.lang3.tuple.Pair; /** @@ -49,6 +51,7 @@ public class ObjectSerializer { private static final byte[] TASK_EXECUTION_EXCEPTION_META = 
      String.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
 
+  public static final byte[] OBJECT_METADATA_TYPE_ARROW = "ARROW".getBytes();
   public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
   public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
   public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
@@ -80,7 +83,9 @@ public static Object deserialize(
 
     if (meta != null && meta.length > 0) {
       // If meta is not null, deserialize the object from meta.
-      if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
+      if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_ARROW) == 0) {
+        return ArrowUtil.deserialize(data);
+      } else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_RAW) == 0) {
         if (objectType == ByteBuffer.class) {
           return ByteBuffer.wrap(data);
         }
@@ -136,6 +141,10 @@ public static NativeRayObject serialize(Object object) {
       // If the object is a byte array, skip serializing it and use a special metadata to
       // indicate it's raw binary. So that this object can also be read by Python.
       return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
+    } else if (object instanceof VectorSchemaRoot) {
+      // Serialize Arrow data using the IPC stream format.
+      byte[] bytes = ArrowUtil.serialize((VectorSchemaRoot) object);
+      return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_ARROW);
     } else if (object instanceof ByteBuffer) {
       // Serialize ByteBuffer to raw bytes.
       ByteBuffer buffer = (ByteBuffer) object;
diff --git a/java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java b/java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java
new file mode 100644
index 000000000000..19286f2a6c40
--- /dev/null
+++ b/java/runtime/src/main/java/io/ray/runtime/util/ArrowUtil.java
@@ -0,0 +1,61 @@
+package io.ray.runtime.util;
+
+import io.ray.api.exception.RayException;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.nio.channels.Channels;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageChannelReader;
+import org.apache.arrow.vector.ipc.message.MessageResult;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/** Helper methods for serializing and deserializing Arrow data. */
+public class ArrowUtil {
+  public static final RootAllocator rootAllocator = new RootAllocator(Long.MAX_VALUE);
+
+  /**
+   * Deserialize a byte array in Arrow IPC stream format into an Arrow vector schema root.
+   *
+   * @return The deserialized vector schema root.
+   */
+  public static VectorSchemaRoot deserialize(byte[] data) {
+    try (MessageChannelReader reader =
+        new MessageChannelReader(
+            new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data))), rootAllocator)) {
+      MessageResult result = reader.readNext();
+      Schema schema = MessageSerializer.deserializeSchema(result.getMessage());
+      VectorSchemaRoot root = VectorSchemaRoot.create(schema, rootAllocator);
+      VectorLoader loader = new VectorLoader(root);
+      result = reader.readNext();
+      ArrowRecordBatch batch =
+          MessageSerializer.deserializeRecordBatch(result.getMessage(), result.getBodyBuffer());
+      loader.load(batch);
+      return root;
+    } catch (Exception e) {
+      throw new RayException("Failed to deserialize Arrow data", e);
+    }
+  }
+
+  /**
+   * Serialize an Arrow vector schema root into a byte array in IPC stream format.
+   *
+   * @return The serialized bytes.
+   */
+  public static byte[] serialize(VectorSchemaRoot root) {
+    try (ByteArrayOutputStream sink = new ByteArrayOutputStream();
+        ArrowStreamWriter writer = new ArrowStreamWriter(root, null, sink)) {
+      writer.start();
+      writer.writeBatch();
+      writer.end();
+      return sink.toByteArray();
+    } catch (Exception e) {
+      throw new RayException("Failed to serialize Arrow data", e);
+    }
+  }
+}
diff --git a/java/test/src/main/java/io/ray/test/CrossLanguageObjectStoreTest.java b/java/test/src/main/java/io/ray/test/CrossLanguageObjectStoreTest.java
new file mode 100644
index 000000000000..c5665346bcb4
--- /dev/null
+++ b/java/test/src/main/java/io/ray/test/CrossLanguageObjectStoreTest.java
@@ -0,0 +1,92 @@
+package io.ray.test;
+
+import io.ray.api.ObjectRef;
+import io.ray.api.Ray;
+import io.ray.api.function.PyFunction;
+import io.ray.runtime.util.ArrowUtil;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.commons.io.FileUtils;
+import org.testng.Assert;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+@Test(groups = {"cluster"})
+public class CrossLanguageObjectStoreTest extends BaseTest {
+
+  private static final String PYTHON_MODULE = "test_cross_language_invocation";
+  private static final int vecSize = 5;
+
+  @BeforeClass
+  public void beforeClass() {
+    // Delete and re-create the temp dir.
+    File tempDir =
+        new File(
+            System.getProperty("java.io.tmpdir")
+                + File.separator
+                + "ray_cross_language_object_store_test");
+    FileUtils.deleteQuietly(tempDir);
+    tempDir.mkdirs();
+    tempDir.deleteOnExit();
+
+    // Write the test Python file to the temp dir.
+ InputStream in = + CrossLanguageObjectStoreTest.class.getResourceAsStream( + File.separator + PYTHON_MODULE + ".py"); + File pythonFile = new File(tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py"); + try { + FileUtils.copyInputStreamToFile(in, pythonFile); + } catch (IOException e) { + throw new RuntimeException(e); + } + + System.setProperty( + "ray.job.code-search-path", + System.getProperty("java.class.path") + File.pathSeparator + tempDir.getAbsolutePath()); + } + + @Test + public void testPythonPutAndJavaGet() { + ObjectRef res = + Ray.task(PyFunction.of(PYTHON_MODULE, "py_put_into_object_store", VectorSchemaRoot.class)) + .remote(); + VectorSchemaRoot root = res.get(); + BigIntVector newVector = (BigIntVector) root.getVector(0); + for (int i = 0; i < vecSize; i++) { + Assert.assertEquals(i, newVector.get(i)); + } + } + + @Test + public void testJavaPutAndPythonGet() { + BigIntVector vector = new BigIntVector("ArrowBigIntVector", ArrowUtil.rootAllocator); + vector.setValueCount(vecSize); + for (int i = 0; i < vecSize; i++) { + vector.setSafe(i, i); + } + List fields = Arrays.asList(vector.getField()); + List vectors = Arrays.asList(vector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors); + ObjectRef obj = Ray.put(root); + + ObjectRef res = + Ray.task( + PyFunction.of( + PYTHON_MODULE, "py_object_store_get_and_check", VectorSchemaRoot.class), + obj) + .remote(); + + VectorSchemaRoot newRoot = res.get(); + BigIntVector newVector = (BigIntVector) newRoot.getVector(0); + for (int i = 0; i < vecSize; i++) { + Assert.assertEquals(i, newVector.get(i)); + } + } +} diff --git a/java/test/src/main/java/io/ray/test/ObjectStoreTest.java b/java/test/src/main/java/io/ray/test/ObjectStoreTest.java index 43955eedc98c..85a187789c41 100644 --- a/java/test/src/main/java/io/ray/test/ObjectStoreTest.java +++ b/java/test/src/main/java/io/ray/test/ObjectStoreTest.java @@ -6,8 +6,14 @@ import io.ray.api.Ray; import io.ray.api.exception.RayTaskException; import io.ray.api.exception.UnreconstructableException; +import io.ray.runtime.util.ArrowUtil; +import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; import org.testng.Assert; import org.testng.annotations.Test; @@ -75,6 +81,25 @@ public void testGetMultipleObjects() { Assert.assertEquals(ints, Ray.get(refs)); } + @Test + public void testArrowObjects() { + final int vecSize = 10; + IntVector vector = new IntVector("ArrowIntVector", ArrowUtil.rootAllocator); + vector.setValueCount(vecSize); + for (int i = 0; i < vecSize; i++) { + vector.setSafe(i, i); + } + List fields = Arrays.asList(vector.getField()); + List vectors = Arrays.asList(vector); + VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors); + ObjectRef obj = Ray.put(root); + VectorSchemaRoot newRoot = obj.get(); + IntVector newVector = (IntVector) newRoot.getVector(0); + for (int i = 0; i < vecSize; i++) { + Assert.assertEquals(i, newVector.get(i)); + } + } + @Test(groups = {"cluster"}) public void testOwnerAssignWhenPut() throws Exception { // This test should align with test_owner_assign_when_put in Python diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index fc6de702ab09..4ea96498e8c1 100644 --- 
a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -3,6 +3,8 @@ import asyncio +import pyarrow as pa + import ray @@ -189,3 +191,35 @@ def py_func_call_java_overloaded_method(): result = ray.get([ref1, ref2]) assert result == ["first", "firstsecond"] return True + + +@ray.remote +def py_put_into_object_store(): + column_values = [0, 1, 2, 3, 4] + column_array = pa.array(column_values) + table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"]) + return table + + +@ray.remote +def py_object_store_get_and_check(table): + column_values = [0, 1, 2, 3, 4] + column_array = pa.array(column_values) + expected_table = pa.Table.from_arrays([column_array], names=["ArrowBigIntVector"]) + + for column_name in table.column_names: + column1 = table[column_name] + column2 = expected_table[column_name] + + indices = pa.compute.equal(column1, column2).to_pylist() + differing_rows = [i for i, index in enumerate(indices) if not index] + + if differing_rows: + print(f"Differences in column '{column_name}':") + for row in differing_rows: + value1 = column1[row].as_py() + value2 = column2[row].as_py() + print(f"Row {row}: {value1} != {value2}") + raise RuntimeError("Check failed, two tables are not equal!") + + return table diff --git a/python/ray/_private/ray_constants.py b/python/ray/_private/ray_constants.py index c274095b2750..3659ee0ce639 100644 --- a/python/ray/_private/ray_constants.py +++ b/python/ray/_private/ray_constants.py @@ -301,6 +301,8 @@ def env_set_by_user(key): OBJECT_METADATA_TYPE_PYTHON = b"PYTHON" # A constant used as object metadata to indicate the object is raw bytes. OBJECT_METADATA_TYPE_RAW = b"RAW" +# A constant used as object metadata to indicate the object is arrow data. +OBJECT_METADATA_TYPE_ARROW = b"ARROW" # A constant used as object metadata to indicate the object is an actor handle. # This value should be synchronized with the Java definition in diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 64b2b1129fe3..7c604514fe72 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -8,6 +8,7 @@ import ray.cloudpickle as pickle from ray._private import ray_constants from ray._raylet import ( + ArrowSerializedObject, MessagePackSerializedObject, MessagePackSerializer, ObjectRefGenerator, @@ -47,6 +48,11 @@ from ray.util import serialization_addons from ray.util import inspect_serializability +try: + import pyarrow as pa +except ImportError: + pa = None + logger = logging.getLogger(__name__) @@ -270,6 +276,12 @@ def _deserialize_object(self, data, metadata, object_ref): if data is None: return b"" return data.to_pybytes() + elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: + assert ( + pa is not None + ), "pyarrow should be imported while deserializing arrow objects" + reader = pa.BufferReader(data) + return pa.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: obj = self._deserialize_msgpack_data(data, metadata_fields) return _actor_handle_deserializer(obj) @@ -461,5 +473,12 @@ def serialize(self, value): # use a special metadata to indicate it's raw binary. So # that this object can also be read by Java. return RawSerializedObject(value) - else: - return self._serialize_to_msgpack(value) + + # Check whether arrow is installed. If so, use Arrow IPC format + # to serialize this object, then it can also be read by Java. 
+ if pa is not None and ( + isinstance(value, pa.Table) or isinstance(value, pa.RecordBatch) + ): + return ArrowSerializedObject(value) + + return self._serialize_to_msgpack(value) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 9726dc65c1ed..956f3d0386b6 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -277,7 +277,9 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table": cols = table.columns new_cols = [] for col in cols: - if _is_column_extension_type(col): + if col.num_chunks == 0: + arr = pyarrow.chunked_array([], type=col.type) + elif _is_column_extension_type(col): # Extension arrays don't support concatenation. arr = _concatenate_extension_column(col) else: diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index 42303482d5b2..c5ef0814927a 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -536,3 +536,37 @@ cdef class RawSerializedObject(SerializedObject): MEMCOPY_THREADS) else: memcpy(&buffer[0], self.value_ptr, self._total_bytes) + + +try: + import pyarrow as pa +except ImportError: + pa = None + +cdef class ArrowSerializedObject(SerializedObject): + cdef: + object value + int64_t _total_bytes + + def __init__(self, value): + super(ArrowSerializedObject, + self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) + self.value = value + sink = pa.MockOutputStream() + writer = pa.ipc.new_stream(sink, self.value.schema) + writer.write(self.value) + writer.close() + self._total_bytes = sink.size() + + @property + def total_bytes(self): + return self._total_bytes + + @cython.boundscheck(False) + @cython.wraparound(False) + cdef void write_to(self, uint8_t[:] buffer) nogil: + with gil: + sink = pa.FixedSizeBufferWriter(pa.py_buffer(buffer)) + writer = pa.ipc.new_stream(sink, self.value.schema) + writer.write(self.value) + writer.close() From 015b2e5ee1abb1c8f9215bdbd24214141afd8cbc Mon Sep 17 00:00:00 2001 From: Deegue Date: Tue, 13 Jun 2023 01:40:15 +0000 Subject: [PATCH 02/25] test --- python/ray/tests/test_advanced_9.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index a4ba35d1756d..86bacd6d7504 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -417,7 +417,7 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): m.delenv("OMP_NUM_THREADS", raising=False) cluster = ray_start_cluster - cluster.add_node(num_cpus=4) + cluster.add_node(num_cpus=2) ray.init(address=cluster.address) @ray.remote(num_cpus=2) From d6d37941185d49a7157415b5c249f86f391ed701 Mon Sep 17 00:00:00 2001 From: Deegue Date: Thu, 15 Jun 2023 10:31:43 +0000 Subject: [PATCH 03/25] test --- python/ray/_private/utils.py | 7 +++++++ python/ray/tests/test_advanced_9.py | 8 +++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 4499dd1a3c76..336017688080 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -308,19 +308,25 @@ def set_omp_num_threads_if_unset() -> bool: """ num_threads_from_env = os.environ.get("OMP_NUM_THREADS") + + logger.warning("SSSSSS:1") if num_threads_from_env is not None: # No ops if it's set + logger.warning("SSSSSS:2," + str(num_threads_from_env)) return False # If unset, try 
setting the correct CPU count assigned. runtime_ctx = ray.get_runtime_context() if runtime_ctx.worker.mode != ray._private.worker.WORKER_MODE: # Non worker mode, no ops. + logger.warning("SSSSSS:3," + str(runtime_ctx.worker.mode)) return False num_assigned_cpus = runtime_ctx.get_assigned_resources().get("CPU") + logger.warning("SSSSSS:4," + str(num_assigned_cpus)) if num_assigned_cpus is None: + logger.warning("SSSSSS:5," + str(num_assigned_cpus)) # This is an actor task w/o any num_cpus specified, set it to 1 logger.debug( "[ray] Forcing OMP_NUM_THREADS=1 to avoid performance " @@ -335,6 +341,7 @@ def set_omp_num_threads_if_unset() -> bool: # For num_cpus >= 1: Set to the floor of the actual assigned cpus. omp_num_threads = max(math.floor(num_assigned_cpus), 1) os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) + logger.warning("SSSSSS:6," + str(omp_num_threads)) return True diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index 86bacd6d7504..ce0ac94d0428 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -416,10 +416,15 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): with monkeypatch.context() as m: m.delenv("OMP_NUM_THREADS", raising=False) + import logging + logger = logging.getLogger(__name__) + cluster = ray_start_cluster - cluster.add_node(num_cpus=2) + cluster.add_node(num_cpus=4) ray.init(address=cluster.address) + logger.warning("SSSSSS:10") + @ray.remote(num_cpus=2) def f(): # Assert numpy using 2 threads for it's parallelism backend. @@ -427,6 +432,7 @@ def f(): from threadpoolctl import threadpool_info for pool_info in threadpool_info(): + logger.warning("SSSSSS:11," + str(pool_info["num_threads"])) assert pool_info["num_threads"] == 2 import numexpr From 709ad056a7887ae37a600704a154b49d571802f4 Mon Sep 17 00:00:00 2001 From: Deegue Date: Fri, 16 Jun 2023 07:39:50 +0000 Subject: [PATCH 04/25] test --- python/ray/_private/utils.py | 5 +++++ python/ray/tests/test_advanced_9.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 336017688080..5c2e68155bd3 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -320,6 +320,11 @@ def set_omp_num_threads_if_unset() -> bool: if runtime_ctx.worker.mode != ray._private.worker.WORKER_MODE: # Non worker mode, no ops. 
logger.warning("SSSSSS:3," + str(runtime_ctx.worker.mode)) + import traceback + import io + buf = io.StringIO() + traceback.print_stack(file=buf) + logger.warning("SSSSSS:buf1:" + buf.getvalue()) return False num_assigned_cpus = runtime_ctx.get_assigned_resources().get("CPU") diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index ce0ac94d0428..a3464d53d550 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -433,6 +433,11 @@ def f(): for pool_info in threadpool_info(): logger.warning("SSSSSS:11," + str(pool_info["num_threads"])) + import traceback + import io + buf = io.StringIO() + traceback.print_stack(file=buf) + logger.warning("SSSSSS:buf2:" + buf.getvalue()) assert pool_info["num_threads"] == 2 import numexpr From 95f92d034b6e228091358256cfafe275fecedd6d Mon Sep 17 00:00:00 2001 From: Deegue Date: Tue, 20 Jun 2023 07:37:16 +0000 Subject: [PATCH 05/25] test --- python/ray/tests/test_advanced_9.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index a3464d53d550..a9357ccc5dd6 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -425,7 +425,7 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): logger.warning("SSSSSS:10") - @ray.remote(num_cpus=2) + @ray.remote(num_cpus=1) def f(): # Assert numpy using 2 threads for it's parallelism backend. import numpy # noqa: F401 From bd9c9af5abe8c49170963953158ff19bb9018e67 Mon Sep 17 00:00:00 2001 From: Deegue Date: Wed, 21 Jun 2023 01:13:47 +0000 Subject: [PATCH 06/25] test --- python/ray/tests/test_advanced_9.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index a9357ccc5dd6..360fcce4159e 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -420,7 +420,7 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): logger = logging.getLogger(__name__) cluster = ray_start_cluster - cluster.add_node(num_cpus=4) + cluster.add_node(num_cpus=1) ray.init(address=cluster.address) logger.warning("SSSSSS:10") From f00591ae8348d4401de77c688c09a77cbbb3ca6d Mon Sep 17 00:00:00 2001 From: Deegue Date: Wed, 28 Jun 2023 02:13:46 +0000 Subject: [PATCH 07/25] revert test log --- python/ray/_private/utils.py | 11 ----------- python/ray/tests/test_advanced_9.py | 13 +------------ 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 5c2e68155bd3..cd287188310a 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -309,29 +309,19 @@ def set_omp_num_threads_if_unset() -> bool: """ num_threads_from_env = os.environ.get("OMP_NUM_THREADS") - logger.warning("SSSSSS:1") if num_threads_from_env is not None: # No ops if it's set - logger.warning("SSSSSS:2," + str(num_threads_from_env)) return False # If unset, try setting the correct CPU count assigned. runtime_ctx = ray.get_runtime_context() if runtime_ctx.worker.mode != ray._private.worker.WORKER_MODE: # Non worker mode, no ops. 
- logger.warning("SSSSSS:3," + str(runtime_ctx.worker.mode)) - import traceback - import io - buf = io.StringIO() - traceback.print_stack(file=buf) - logger.warning("SSSSSS:buf1:" + buf.getvalue()) return False num_assigned_cpus = runtime_ctx.get_assigned_resources().get("CPU") - logger.warning("SSSSSS:4," + str(num_assigned_cpus)) if num_assigned_cpus is None: - logger.warning("SSSSSS:5," + str(num_assigned_cpus)) # This is an actor task w/o any num_cpus specified, set it to 1 logger.debug( "[ray] Forcing OMP_NUM_THREADS=1 to avoid performance " @@ -346,7 +336,6 @@ def set_omp_num_threads_if_unset() -> bool: # For num_cpus >= 1: Set to the floor of the actual assigned cpus. omp_num_threads = max(math.floor(num_assigned_cpus), 1) os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) - logger.warning("SSSSSS:6," + str(omp_num_threads)) return True diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index 360fcce4159e..40cc1d9a4632 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -416,15 +416,10 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): with monkeypatch.context() as m: m.delenv("OMP_NUM_THREADS", raising=False) - import logging - logger = logging.getLogger(__name__) - cluster = ray_start_cluster - cluster.add_node(num_cpus=1) + cluster.add_node(num_cpus=4) ray.init(address=cluster.address) - logger.warning("SSSSSS:10") - @ray.remote(num_cpus=1) def f(): # Assert numpy using 2 threads for it's parallelism backend. @@ -432,12 +427,6 @@ def f(): from threadpoolctl import threadpool_info for pool_info in threadpool_info(): - logger.warning("SSSSSS:11," + str(pool_info["num_threads"])) - import traceback - import io - buf = io.StringIO() - traceback.print_stack(file=buf) - logger.warning("SSSSSS:buf2:" + buf.getvalue()) assert pool_info["num_threads"] == 2 import numexpr From ad28538fbd504bd4146ae443c144bdc423743883 Mon Sep 17 00:00:00 2001 From: Deegue Date: Fri, 7 Jul 2023 02:32:08 +0000 Subject: [PATCH 08/25] move import inside --- python/ray/_private/serialization.py | 14 ++++++-------- python/ray/_private/utils.py | 1 - python/ray/includes/serialization.pxi | 7 ++----- python/ray/tests/test_advanced_9.py | 2 +- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 7c604514fe72..a543b3b2029e 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -48,11 +48,6 @@ from ray.util import serialization_addons from ray.util import inspect_serializability -try: - import pyarrow as pa -except ImportError: - pa = None - logger = logging.getLogger(__name__) @@ -277,9 +272,7 @@ def _deserialize_object(self, data, metadata, object_ref): return b"" return data.to_pybytes() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: - assert ( - pa is not None - ), "pyarrow should be imported while deserializing arrow objects" + import pyarrow as pa reader = pa.BufferReader(data) return pa.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: @@ -474,6 +467,11 @@ def serialize(self, value): # that this object can also be read by Java. return RawSerializedObject(value) + try: + import pyarrow as pa + except ImportError: + pa = None + # Check whether arrow is installed. If so, use Arrow IPC format # to serialize this object, then it can also be read by Java. 
if pa is not None and ( diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index cd287188310a..4499dd1a3c76 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -308,7 +308,6 @@ def set_omp_num_threads_if_unset() -> bool: """ num_threads_from_env = os.environ.get("OMP_NUM_THREADS") - if num_threads_from_env is not None: # No ops if it's set return False diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index c5ef0814927a..f3a46931e623 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -538,16 +538,13 @@ cdef class RawSerializedObject(SerializedObject): memcpy(&buffer[0], self.value_ptr, self._total_bytes) -try: - import pyarrow as pa -except ImportError: - pa = None - cdef class ArrowSerializedObject(SerializedObject): cdef: object value int64_t _total_bytes + import pyarrow as pa + def __init__(self, value): super(ArrowSerializedObject, self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) diff --git a/python/ray/tests/test_advanced_9.py b/python/ray/tests/test_advanced_9.py index 40cc1d9a4632..a4ba35d1756d 100644 --- a/python/ray/tests/test_advanced_9.py +++ b/python/ray/tests/test_advanced_9.py @@ -420,7 +420,7 @@ def test_omp_threads_set_third_party(ray_start_cluster, monkeypatch): cluster.add_node(num_cpus=4) ray.init(address=cluster.address) - @ray.remote(num_cpus=1) + @ray.remote(num_cpus=2) def f(): # Assert numpy using 2 threads for it's parallelism backend. import numpy # noqa: F401 From c3e5db3ca09b9743e713abc1a78ff860c656d30d Mon Sep 17 00:00:00 2001 From: Deegue Date: Fri, 7 Jul 2023 02:51:06 +0000 Subject: [PATCH 09/25] nit --- python/ray/includes/serialization.pxi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index f3a46931e623..b8f33c2532cf 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -543,9 +543,8 @@ cdef class ArrowSerializedObject(SerializedObject): object value int64_t _total_bytes - import pyarrow as pa - def __init__(self, value): + import pyarrow as pa super(ArrowSerializedObject, self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) self.value = value @@ -563,6 +562,7 @@ cdef class ArrowSerializedObject(SerializedObject): @cython.wraparound(False) cdef void write_to(self, uint8_t[:] buffer) nogil: with gil: + import pyarrow as pa sink = pa.FixedSizeBufferWriter(pa.py_buffer(buffer)) writer = pa.ipc.new_stream(sink, self.value.schema) writer.write(self.value) From 03e64d60ea3736b89e0df4836019e025551383d2 Mon Sep 17 00:00:00 2001 From: Deegue Date: Fri, 7 Jul 2023 06:06:17 +0000 Subject: [PATCH 10/25] lint and fix test --- .../main/resources/test_cross_language_invocation.py | 12 +----------- python/ray/_private/serialization.py | 1 + 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index 4ea96498e8c1..a603f991a2d2 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -210,16 +210,6 @@ def py_object_store_get_and_check(table): for column_name in table.column_names: column1 = table[column_name] column2 = expected_table[column_name] - - indices = pa.compute.equal(column1, column2).to_pylist() - differing_rows = [i for i, index in 
enumerate(indices) if not index] - - if differing_rows: - print(f"Differences in column '{column_name}':") - for row in differing_rows: - value1 = column1[row].as_py() - value2 = column2[row].as_py() - print(f"Row {row}: {value1} != {value2}") - raise RuntimeError("Check failed, two tables are not equal!") + assert(column1.equals(column2)) return table diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index a543b3b2029e..28002ebc0142 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -273,6 +273,7 @@ def _deserialize_object(self, data, metadata, object_ref): return data.to_pybytes() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: import pyarrow as pa + reader = pa.BufferReader(data) return pa.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: From ede5faeef9a25ff13c03a1a42ae21485d9e6403f Mon Sep 17 00:00:00 2001 From: Deegue Date: Fri, 7 Jul 2023 08:11:06 +0000 Subject: [PATCH 11/25] lint --- java/test/src/main/resources/test_cross_language_invocation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index a603f991a2d2..c3aa53dd15f3 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -210,6 +210,6 @@ def py_object_store_get_and_check(table): for column_name in table.column_names: column1 = table[column_name] column2 = expected_table[column_name] - assert(column1.equals(column2)) + assert column1.equals(column2) return table From e57a8de78667cd39822577ae2d313131e9667136 Mon Sep 17 00:00:00 2001 From: Deegue Date: Mon, 10 Jul 2023 07:45:45 +0000 Subject: [PATCH 12/25] move outside --- python/ray/includes/serialization.pxi | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index b8f33c2532cf..a0e7cecfa289 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -2,6 +2,7 @@ from libc.string cimport memcpy from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX import contextlib import cython +import pyarrow DEF MEMCOPY_THREADS = 6 @@ -544,12 +545,11 @@ cdef class ArrowSerializedObject(SerializedObject): int64_t _total_bytes def __init__(self, value): - import pyarrow as pa super(ArrowSerializedObject, self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) self.value = value - sink = pa.MockOutputStream() - writer = pa.ipc.new_stream(sink, self.value.schema) + sink = pyarrow.MockOutputStream() + writer = pyarrow.ipc.new_stream(sink, self.value.schema) writer.write(self.value) writer.close() self._total_bytes = sink.size() @@ -562,8 +562,7 @@ cdef class ArrowSerializedObject(SerializedObject): @cython.wraparound(False) cdef void write_to(self, uint8_t[:] buffer) nogil: with gil: - import pyarrow as pa - sink = pa.FixedSizeBufferWriter(pa.py_buffer(buffer)) - writer = pa.ipc.new_stream(sink, self.value.schema) + sink = pyarrow.FixedSizeBufferWriter(pyarrow.py_buffer(buffer)) + writer = pyarrow.ipc.new_stream(sink, self.value.schema) writer.write(self.value) writer.close() From 74c85f059ca28cf5dacd625801d07850fbbd6496 Mon Sep 17 00:00:00 2001 From: Deegue Date: Mon, 10 Jul 2023 08:56:42 +0000 Subject: [PATCH 13/25] nit --- 
python/ray/includes/serialization.pxi | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index a0e7cecfa289..3b48d3801d9f 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -2,7 +2,6 @@ from libc.string cimport memcpy from libc.stdint cimport uintptr_t, uint64_t, INT32_MAX import contextlib import cython -import pyarrow DEF MEMCOPY_THREADS = 6 @@ -539,6 +538,8 @@ cdef class RawSerializedObject(SerializedObject): memcpy(&buffer[0], self.value_ptr, self._total_bytes) +import pyarrow + cdef class ArrowSerializedObject(SerializedObject): cdef: object value From b96edbfb526fa93fcc12b528f997b4902cee8cb5 Mon Sep 17 00:00:00 2001 From: Deegue Date: Mon, 10 Jul 2023 09:37:26 +0000 Subject: [PATCH 14/25] nit --- python/ray/includes/serialization.pxi | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index 3b48d3801d9f..804e51f5b43a 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -538,7 +538,6 @@ cdef class RawSerializedObject(SerializedObject): memcpy(&buffer[0], self.value_ptr, self._total_bytes) -import pyarrow cdef class ArrowSerializedObject(SerializedObject): cdef: @@ -546,6 +545,8 @@ cdef class ArrowSerializedObject(SerializedObject): int64_t _total_bytes def __init__(self, value): + import pyarrow + super(ArrowSerializedObject, self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) self.value = value @@ -562,6 +563,8 @@ cdef class ArrowSerializedObject(SerializedObject): @cython.boundscheck(False) @cython.wraparound(False) cdef void write_to(self, uint8_t[:] buffer) nogil: + import pyarrow + with gil: sink = pyarrow.FixedSizeBufferWriter(pyarrow.py_buffer(buffer)) writer = pyarrow.ipc.new_stream(sink, self.value.schema) From 3fa10c31a0141eabaa6279f816d588c9952d9449 Mon Sep 17 00:00:00 2001 From: Deegue Date: Tue, 11 Jul 2023 04:04:50 +0000 Subject: [PATCH 15/25] test --- python/ray/_private/serialization.py | 17 ++++++++++------- python/ray/includes/serialization.pxi | 3 +-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 9fc8d7511ea4..e70def647bfc 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -273,10 +273,13 @@ def _deserialize_object(self, data, metadata, object_ref): return b"" return data.to_pybytes() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: - import pyarrow as pa + try: + import pyarrow + except ImportError: + pyarrow = None - reader = pa.BufferReader(data) - return pa.ipc.open_stream(reader).read_all() + reader = pyarrow.BufferReader(data) + return pyarrow.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: obj = self._deserialize_msgpack_data(data, metadata_fields) return _actor_handle_deserializer(obj) @@ -472,14 +475,14 @@ def serialize(self, value): return RawSerializedObject(value) try: - import pyarrow as pa + import pyarrow except ImportError: - pa = None + pyarrow = None # Check whether arrow is installed. If so, use Arrow IPC format # to serialize this object, then it can also be read by Java. 
- if pa is not None and ( - isinstance(value, pa.Table) or isinstance(value, pa.RecordBatch) + if pyarrow is not None and ( + isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) ): return ArrowSerializedObject(value) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index 804e51f5b43a..33584b2e09eb 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -563,9 +563,8 @@ cdef class ArrowSerializedObject(SerializedObject): @cython.boundscheck(False) @cython.wraparound(False) cdef void write_to(self, uint8_t[:] buffer) nogil: - import pyarrow - with gil: + import pyarrow sink = pyarrow.FixedSizeBufferWriter(pyarrow.py_buffer(buffer)) writer = pyarrow.ipc.new_stream(sink, self.value.schema) writer.write(self.value) From e390b6a2c8226bfed9972b4981b9eea5568f6dd0 Mon Sep 17 00:00:00 2001 From: Deegue Date: Wed, 12 Jul 2023 06:30:03 +0000 Subject: [PATCH 16/25] lint --- python/ray/includes/serialization.pxi | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/ray/includes/serialization.pxi b/python/ray/includes/serialization.pxi index 33584b2e09eb..3ac8d3e8f062 100644 --- a/python/ray/includes/serialization.pxi +++ b/python/ray/includes/serialization.pxi @@ -538,7 +538,6 @@ cdef class RawSerializedObject(SerializedObject): memcpy(&buffer[0], self.value_ptr, self._total_bytes) - cdef class ArrowSerializedObject(SerializedObject): cdef: object value @@ -546,7 +545,6 @@ cdef class ArrowSerializedObject(SerializedObject): def __init__(self, value): import pyarrow - super(ArrowSerializedObject, self).__init__(ray_constants.OBJECT_METADATA_TYPE_ARROW) self.value = value From c02b243c07a5ae4118278f19ae5249327a15a062 Mon Sep 17 00:00:00 2001 From: Deegue Date: Thu, 20 Jul 2023 03:35:47 +0000 Subject: [PATCH 17/25] move inside --- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 8dd265a4c7af..f4eeb5b6ed20 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -1,10 +1,5 @@ from typing import TYPE_CHECKING, List, Union -try: - import pyarrow -except ImportError: - pyarrow = None - if TYPE_CHECKING: from ray.data._internal.sort import SortKeyT @@ -26,6 +21,9 @@ def take_table( extension arrays. This is exposed as a static method for easier use on intermediate tables, not underlying an ArrowBlockAccessor. 
""" + + import pyarrow + from ray.air.util.transform_pyarrow import ( _concatenate_extension_column, _is_column_extension_type, From e6ae146afa89757e6436d5f646ac02b423af8b7e Mon Sep 17 00:00:00 2001 From: Deegue Date: Thu, 20 Jul 2023 06:09:13 +0000 Subject: [PATCH 18/25] nit --- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index f4eeb5b6ed20..787f56992056 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -23,7 +23,6 @@ def take_table( """ import pyarrow - from ray.air.util.transform_pyarrow import ( _concatenate_extension_column, _is_column_extension_type, @@ -49,7 +48,6 @@ def unify_schemas( """Version of `pyarrow.unify_schemas()` which also handles checks for variable-shaped tensors in the given schemas.""" import pyarrow as pa - from ray.air.util.tensor_extensions.arrow import ( ArrowTensorType, ArrowVariableShapedTensorType, @@ -123,6 +121,7 @@ def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.Chunke """ Concatenate provided chunked arrays into a single chunked array. """ + import pyarrow from ray.data.extensions import ArrowTensorType, ArrowVariableShapedTensorType # Single flat list of chunks across all chunked arrays. @@ -256,6 +255,7 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": def concat_and_sort( blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool ) -> "pyarrow.Table": + import pyarrow ret = concat(blocks) indices = pyarrow.compute.sort_indices(ret, sort_keys=key) return take_table(ret, indices) @@ -267,6 +267,7 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table": This will create a new table by combining the chunks the input table has. """ + import pyarrow from ray.air.util.transform_pyarrow import ( _concatenate_extension_column, _is_column_extension_type, From 6aa4ad8fc6fdf099cc1a328f01e2d4b7ae47a8cb Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Fri, 21 Jul 2023 15:23:52 +0000 Subject: [PATCH 19/25] revert import in transform_pyarrow Signed-off-by: Zhi Lin --- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 787f56992056..5b866cf7f435 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -1,5 +1,10 @@ from typing import TYPE_CHECKING, List, Union +try: + import pyarrow +except ImportError: + pyarrow = None + if TYPE_CHECKING: from ray.data._internal.sort import SortKeyT @@ -22,7 +27,6 @@ def take_table( intermediate tables, not underlying an ArrowBlockAccessor. """ - import pyarrow from ray.air.util.transform_pyarrow import ( _concatenate_extension_column, _is_column_extension_type, @@ -121,7 +125,6 @@ def _concatenate_chunked_arrays(arrs: "pyarrow.ChunkedArray") -> "pyarrow.Chunke """ Concatenate provided chunked arrays into a single chunked array. """ - import pyarrow from ray.data.extensions import ArrowTensorType, ArrowVariableShapedTensorType # Single flat list of chunks across all chunked arrays. 
@@ -255,7 +258,6 @@ def concat(blocks: List["pyarrow.Table"]) -> "pyarrow.Table": def concat_and_sort( blocks: List["pyarrow.Table"], key: "SortKeyT", descending: bool ) -> "pyarrow.Table": - import pyarrow ret = concat(blocks) indices = pyarrow.compute.sort_indices(ret, sort_keys=key) return take_table(ret, indices) @@ -267,7 +269,6 @@ def combine_chunks(table: "pyarrow.Table") -> "pyarrow.Table": This will create a new table by combining the chunks the input table has. """ - import pyarrow from ray.air.util.transform_pyarrow import ( _concatenate_extension_column, _is_column_extension_type, From c47572b3092bb23918f205eca70078c5b1d2042a Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Mon, 24 Jul 2023 20:29:03 +0000 Subject: [PATCH 20/25] comment added part and see if ci pass Signed-off-by: Zhi Lin --- python/ray/_private/serialization.py | 40 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index e70def647bfc..f3104678ebb4 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -8,7 +8,7 @@ import ray.cloudpickle as pickle from ray._private import ray_constants from ray._raylet import ( - ArrowSerializedObject, + # ArrowSerializedObject, MessagePackSerializedObject, MessagePackSerializer, ObjectRefGenerator, @@ -272,14 +272,14 @@ def _deserialize_object(self, data, metadata, object_ref): if data is None: return b"" return data.to_pybytes() - elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: - try: - import pyarrow - except ImportError: - pyarrow = None - - reader = pyarrow.BufferReader(data) - return pyarrow.ipc.open_stream(reader).read_all() + # elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: + # try: + # import pyarrow + # except ImportError: + # pyarrow = None + + # reader = pyarrow.BufferReader(data) + # return pyarrow.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: obj = self._deserialize_msgpack_data(data, metadata_fields) return _actor_handle_deserializer(obj) @@ -474,16 +474,16 @@ def serialize(self, value): # that this object can also be read by Java. return RawSerializedObject(value) - try: - import pyarrow - except ImportError: - pyarrow = None - - # Check whether arrow is installed. If so, use Arrow IPC format - # to serialize this object, then it can also be read by Java. - if pyarrow is not None and ( - isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) - ): - return ArrowSerializedObject(value) + # try: + # import pyarrow + # except ImportError: + # pyarrow = None + + # # Check whether arrow is installed. If so, use Arrow IPC format + # # to serialize this object, then it can also be read by Java. 
+ # if pyarrow is not None and ( + # isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) + # ): + # return ArrowSerializedObject(value) return self._serialize_to_msgpack(value) From 3da0729e69dda2227e3c12e4cb6c39981dcb10b7 Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Tue, 25 Jul 2023 09:00:07 +0000 Subject: [PATCH 21/25] see if import breaks ci Signed-off-by: Zhi Lin --- python/ray/_private/serialization.py | 2 +- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index f3104678ebb4..51e05273c66d 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -8,7 +8,7 @@ import ray.cloudpickle as pickle from ray._private import ray_constants from ray._raylet import ( - # ArrowSerializedObject, + ArrowSerializedObject, MessagePackSerializedObject, MessagePackSerializer, ObjectRefGenerator, diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 5b866cf7f435..72c4acd618e8 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -52,6 +52,7 @@ def unify_schemas( """Version of `pyarrow.unify_schemas()` which also handles checks for variable-shaped tensors in the given schemas.""" import pyarrow as pa + from ray.air.util.tensor_extensions.arrow import ( ArrowTensorType, ArrowVariableShapedTensorType, From 5872b7612a8766e9cd2c12ebcae036693f158be1 Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Tue, 25 Jul 2023 10:24:10 +0000 Subject: [PATCH 22/25] add back deser part Signed-off-by: Zhi Lin --- python/ray/_private/serialization.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 51e05273c66d..c51c04edd914 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -272,14 +272,14 @@ def _deserialize_object(self, data, metadata, object_ref): if data is None: return b"" return data.to_pybytes() - # elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: - # try: - # import pyarrow - # except ImportError: - # pyarrow = None - - # reader = pyarrow.BufferReader(data) - # return pyarrow.ipc.open_stream(reader).read_all() + elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ARROW: + try: + import pyarrow + except ImportError: + pyarrow = None + + reader = pyarrow.BufferReader(data) + return pyarrow.ipc.open_stream(reader).read_all() elif metadata_fields[0] == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE: obj = self._deserialize_msgpack_data(data, metadata_fields) return _actor_handle_deserializer(obj) From 5def2f9a51706faf312d91d99a94115f836e1fef Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Tue, 25 Jul 2023 12:07:47 +0000 Subject: [PATCH 23/25] add back serialize import Signed-off-by: Zhi Lin --- python/ray/_private/serialization.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index c51c04edd914..0ecc1d6cbefa 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -474,13 +474,13 @@ def serialize(self, value): # that this object can also be read by Java. 
return RawSerializedObject(value) - # try: - # import pyarrow - # except ImportError: - # pyarrow = None + try: + import pyarrow + except ImportError: + pyarrow = None - # # Check whether arrow is installed. If so, use Arrow IPC format - # # to serialize this object, then it can also be read by Java. + # Check whether arrow is installed. If so, use Arrow IPC format + # to serialize this object, then it can also be read by Java. # if pyarrow is not None and ( # isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) # ): From 371224d1bf692f9cbb47c8b8705479132a63e558 Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Tue, 25 Jul 2023 15:19:07 +0000 Subject: [PATCH 24/25] return table if indices is empty Signed-off-by: Zhi Lin --- python/ray/_private/serialization.py | 8 ++++---- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/ray/_private/serialization.py b/python/ray/_private/serialization.py index 0ecc1d6cbefa..e70def647bfc 100644 --- a/python/ray/_private/serialization.py +++ b/python/ray/_private/serialization.py @@ -481,9 +481,9 @@ def serialize(self, value): # Check whether arrow is installed. If so, use Arrow IPC format # to serialize this object, then it can also be read by Java. - # if pyarrow is not None and ( - # isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) - # ): - # return ArrowSerializedObject(value) + if pyarrow is not None and ( + isinstance(value, pyarrow.Table) or isinstance(value, pyarrow.RecordBatch) + ): + return ArrowSerializedObject(value) return self._serialize_to_msgpack(value) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 72c4acd618e8..83efce7da5cc 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -31,7 +31,8 @@ def take_table( _concatenate_extension_column, _is_column_extension_type, ) - + if len(indices) == 0: + return table if any(_is_column_extension_type(col) for col in table.columns): new_cols = [] for col in table.columns: From 94ed413a9b9d93156e38b557f734a43444bb6249 Mon Sep 17 00:00:00 2001 From: Zhi Lin Date: Wed, 26 Jul 2023 13:29:31 +0000 Subject: [PATCH 25/25] add empty check for take_table Signed-off-by: Zhi Lin --- python/ray/data/_internal/arrow_ops/transform_pyarrow.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py index 83efce7da5cc..f4ebb8ac7d05 100644 --- a/python/ray/data/_internal/arrow_ops/transform_pyarrow.py +++ b/python/ray/data/_internal/arrow_ops/transform_pyarrow.py @@ -31,8 +31,11 @@ def take_table( _concatenate_extension_column, _is_column_extension_type, ) - if len(indices) == 0: + + if table.num_rows == 0: return table + if len(indices) == 0: + return pyarrow.Table.from_pydict({}) if any(_is_column_extension_type(col) for col in table.columns): new_cols = [] for col in table.columns:
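For reference, the serialization path this patch series wires into Ray's object store can be exercised with pyarrow alone. The sketch below mirrors what `ArrowSerializedObject` (serialization.pxi) and `_deserialize_object` (serialization.py) do: size the IPC stream with a `MockOutputStream`, write the same stream into a fixed-size buffer, then read it back. It is a minimal illustration, not Ray code; the table name and values simply echo the cross-language test above.

```python
import pyarrow as pa

table = pa.Table.from_arrays([pa.array([0, 1, 2, 3, 4])], names=["ArrowBigIntVector"])

# Pass 1: measure the exact size of the IPC stream without copying any data.
sink = pa.MockOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write(table)
total_bytes = sink.size()

# Pass 2: write the stream into a preallocated fixed-size buffer
# (in Ray, this buffer lives in the plasma object store).
buf = pa.allocate_buffer(total_bytes)
with pa.ipc.new_stream(pa.FixedSizeBufferWriter(buf), table.schema) as writer:
    writer.write(table)

# Deserialization: open the buffer as an IPC stream and read it all back.
restored = pa.ipc.open_stream(pa.BufferReader(buf)).read_all()
assert restored.equals(table)
```

The two-pass design is what lets the write in `write_to` target a buffer whose size was declared up front via `total_bytes`, which is how Ray's plasma store allocates object memory before serialization runs.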