Revert "[Datasets] Arrow 7.0.0+ Support: Use Arrow IPC format for pic…
Browse files Browse the repository at this point in the history
…kling Arrow data to circumvent slice view buffer truncation bug. (#29055)" (#29138)

This reverts commit c1d62d4.

This is breaking the Windows build: https://buildkite.com/ray-project/oss-ci-build-branch/builds/365#0183ada0-8558-4502-9298-7e0bed873e23
clarkzinzow authored Oct 7, 2022
1 parent 2b5f041 commit 081ce2f
Showing 24 changed files with 105 additions and 646 deletions.
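
For context on the reverted change: #29055 serialized Arrow data via the Arrow IPC format to work around the zero-copy slice pickling bug tracked as ARROW-10739 (see the removed docstrings below), in which pickling an array-level slice serializes the entire backing buffer rather than just the sliced region. A minimal sketch of that behavior, assuming a pyarrow version affected by the bug; exact sizes vary by version:

```python
import pickle

import pyarrow as pa

arr = pa.array(range(1_000_000))  # ~8 MB of int64 data
view = arr.slice(0, 10)           # zero-copy slice; shares the parent buffer

print(len(pickle.dumps(arr)))   # ~8 MB
print(len(pickle.dumps(view)))  # on affected pyarrow versions, also ~8 MB
                                # instead of a few hundred bytes
```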
11 changes: 6 additions & 5 deletions python/ray/air/util/transform_pyarrow.py
@@ -22,9 +22,10 @@ def _concatenate_extension_column(ca: "pyarrow.ChunkedArray") -> "pyarrow.Array"
raise ValueError(f"Chunked array isn't an extension array: {ca}")

if ca.num_chunks == 0:
# Create empty storage array.
storage = pyarrow.array([], type=ca.type.storage_type)
else:
storage = pyarrow.concat_arrays([c.storage for c in ca.chunks])
# No-op for no-chunk chunked arrays, since there's nothing to concatenate.
return ca

return ca.type.__arrow_ext_class__().from_storage(ca.type, storage)
chunk = ca.chunk(0)
return type(chunk).from_storage(
chunk.type, pyarrow.concat_arrays([c.storage for c in ca.chunks])
)
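
Both versions of _concatenate_extension_column shown above concatenate the chunks' storage arrays and then wrap the combined storage back into a single extension array. A minimal, self-contained sketch of that storage round-trip, using a throwaway extension type (hypothetical, purely for illustration; Ray's tensor extension types are not reproduced here):

```python
import pyarrow as pa


class UuidType(pa.ExtensionType):
    """Hypothetical extension type, used only to demonstrate the storage round-trip."""

    def __init__(self):
        super().__init__(pa.binary(16), "example.uuid")

    def __arrow_ext_serialize__(self):
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return cls()


uuid_type = UuidType()
storage = pa.array([b"0" * 16, b"1" * 16], pa.binary(16))
ext_chunk = pa.ExtensionArray.from_storage(uuid_type, storage)
ca = pa.chunked_array([ext_chunk, ext_chunk])

# Mirror of the post-revert logic: concatenate the chunks' storage arrays, then
# rebuild a single extension array of the same type from the combined storage.
chunk = ca.chunk(0)
combined = type(chunk).from_storage(
    chunk.type, pa.concat_arrays([c.storage for c in ca.chunks])
)
print(len(combined))  # 4
```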
14 changes: 14 additions & 0 deletions python/ray/data/__init__.py
@@ -1,3 +1,8 @@
import ray
from ray.data._internal.arrow_serialization import (
_register_arrow_json_parseoptions_serializer,
_register_arrow_json_readoptions_serializer,
)
from ray.data._internal.compute import ActorPoolStrategy
from ray.data._internal.progress_bar import set_progress_bars
from ray.data.dataset import Dataset
@@ -32,6 +37,15 @@
read_tfrecords,
)

# Register custom Arrow JSON ReadOptions and ParseOptions serializer after worker has
# initialized.
if ray.is_initialized():
_register_arrow_json_readoptions_serializer()
_register_arrow_json_parseoptions_serializer()
else:
pass
# ray._internal.worker._post_init_hooks.append(_register_arrow_json_readoptions_serializer)

__all__ = [
"ActorPoolStrategy",
"Dataset",
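
With the revert in place, the ReadOptions/ParseOptions serializers registered above make these Arrow JSON option objects safe to pass into Ray tasks on pyarrow versions where they are not picklable out of the box (they became picklable in Arrow 8.0.0). A rough usage sketch, assuming ray.data is imported after ray.init() so the registration branch above runs; the task name is hypothetical:

```python
import ray

ray.init()
import ray.data  # noqa: E402,F401  # triggers the serializer registration above

import pyarrow.json as pajson


@ray.remote
def get_block_size(opts: pajson.ReadOptions) -> int:
    # The options object crossed the task boundary via the custom serializer,
    # which ships only (use_threads, block_size) and rebuilds ReadOptions.
    return opts.block_size


opts = pajson.ReadOptions(use_threads=True, block_size=2**20)
print(ray.get(get_block_size.remote(opts)))  # 1048576
```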
17 changes: 9 additions & 8 deletions python/ray/data/_internal/arrow_block.py
@@ -50,7 +50,6 @@

from ray.data._internal.sort import SortKeyT


T = TypeVar("T")


@@ -176,7 +175,7 @@ def _build_tensor_row(row: ArrowRow) -> np.ndarray:
# Getting an item in a tensor column automatically does a NumPy conversion.
return row[VALUE_COL_NAME][0]

def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
def slice(self, start: int, end: int, copy: bool) -> "pyarrow.Table":
view = self._table.slice(start, end - start)
if copy:
view = _copy_table(view)
@@ -208,10 +207,10 @@ def to_numpy(
arrays = []
for column in columns:
array = self._table[column]
if _is_column_extension_type(array):
array = _concatenate_extension_column(array)
elif array.num_chunks == 0:
if array.num_chunks == 0:
array = pyarrow.array([], type=array.type)
elif _is_column_extension_type(array):
array = _concatenate_extension_column(array)
else:
array = array.combine_chunks()
arrays.append(array.to_numpy(zero_copy_only=False))
@@ -395,9 +394,11 @@ def sort_and_partition(
bounds = np.searchsorted(table[col], boundaries)
last_idx = 0
for idx in bounds:
partitions.append(table.slice(last_idx, idx - last_idx))
# Slices need to be copied to avoid including the base table
# during serialization.
partitions.append(_copy_table(table.slice(last_idx, idx - last_idx)))
last_idx = idx
partitions.append(table.slice(last_idx))
partitions.append(_copy_table(table.slice(last_idx)))
return partitions

def combine(self, key: KeyFn, aggs: Tuple[AggregateFn]) -> Block[ArrowRow]:
@@ -443,7 +444,7 @@ def iter_groups() -> Iterator[Tuple[KeyType, Block]]:
except StopIteration:
next_row = None
break
yield next_key, self.slice(start, end)
yield next_key, self.slice(start, end, copy=False)
start = end
except StopIteration:
break
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_polars.py
@@ -7,7 +7,7 @@


if TYPE_CHECKING:
from ray.data._internal.sort import SortKeyT
from ray.data.impl.sort import SortKeyT

pl = None

2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
@@ -11,7 +11,7 @@
pyarrow = None

if TYPE_CHECKING:
from ray.data._internal.sort import SortKeyT
from ray.data.impl.sort import SortKeyT


def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.Table":
231 changes: 27 additions & 204 deletions python/ray/data/_internal/arrow_serialization.py
@@ -1,242 +1,65 @@
import functools
import os
from typing import List, Callable, TYPE_CHECKING

if TYPE_CHECKING:
import pyarrow

RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = (
"RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"
)
RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION = (
"RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION"
)


def _register_custom_datasets_serializers(serialization_context):
try:
import pyarrow as pa # noqa: F401
except ModuleNotFoundError:
# No pyarrow installed so not using Arrow, so no need for custom serializers.
return

# Register all custom serializers required by Datasets.
_register_arrow_data_serializer(serialization_context)
_register_arrow_json_readoptions_serializer(serialization_context)
_register_arrow_json_parseoptions_serializer(serialization_context)

def _register_arrow_json_readoptions_serializer():
import ray

# Register custom Arrow JSON ReadOptions serializer to workaround it not being picklable
# in Arrow < 8.0.0.
def _register_arrow_json_readoptions_serializer(serialization_context):
if (
os.environ.get(
RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
"0",
)
== "1"
):
import logging

logger = logging.getLogger(__name__)
logger.info("Disabling custom Arrow JSON ReadOptions serialization.")
return

import pyarrow.json as pajson
try:
import pyarrow.json as pajson
except ModuleNotFoundError:
return

serialization_context._register_cloudpickle_serializer(
ray.util.register_serializer(
pajson.ReadOptions,
custom_serializer=lambda opts: (opts.use_threads, opts.block_size),
custom_deserializer=lambda args: pajson.ReadOptions(*args),
serializer=lambda opts: (opts.use_threads, opts.block_size),
deserializer=lambda args: pajson.ReadOptions(*args),
)


def _register_arrow_json_parseoptions_serializer(serialization_context):
def _register_arrow_json_parseoptions_serializer():
import ray

if (
os.environ.get(
RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION,
"0",
)
== "1"
):
import logging

logger = logging.getLogger(__name__)
logger.info("Disabling custom Arrow JSON ParseOptions serialization.")
return

import pyarrow.json as pajson
try:
import pyarrow.json as pajson
except ModuleNotFoundError:
return

serialization_context._register_cloudpickle_serializer(
ray.util.register_serializer(
pajson.ParseOptions,
custom_serializer=lambda opts: (
serializer=lambda opts: (
opts.explicit_schema,
opts.newlines_in_values,
opts.unexpected_field_behavior,
),
custom_deserializer=lambda args: pajson.ParseOptions(*args),
)


# Register custom Arrow data serializer to work around zero-copy slice pickling bug.
# See https://issues.apache.org/jira/browse/ARROW-10739.
def _register_arrow_data_serializer(serialization_context):
"""Custom reducer for Arrow data that works around a zero-copy slicing pickling
bug by using the Arrow IPC format for the underlying serialization.
Background:
Arrow has both array-level slicing and buffer-level slicing; both are zero-copy,
but the former has a serialization bug where the entire buffer is serialized
instead of just the slice, while the latter's serialization works as expected
and only serializes the slice of the buffer. I.e., array-level slicing doesn't
propagate the slice down to the buffer when serializing the array.
All that these copy methods do is, at serialization time, take the array-level
slicing and translate them to buffer-level slicing, so only the buffer slice is
sent over the wire instead of the entire buffer.
See https://issues.apache.org/jira/browse/ARROW-10739.
"""
import pyarrow as pa

if os.environ.get(RAY_DISABLE_CUSTOM_ARROW_DATA_SERIALIZATION, "0") == "1":
return

# Register custom reducer for Arrow Arrays.
array_types = _get_arrow_array_types()
for array_type in array_types:
serialization_context._register_cloudpickle_reducer(
array_type, _arrow_array_reduce
)
# Register custom reducer for Arrow ChunkedArrays.
serialization_context._register_cloudpickle_reducer(
pa.ChunkedArray, _arrow_chunkedarray_reduce
)
# Register custom reducer for Arrow RecordBatches.
serialization_context._register_cloudpickle_reducer(
pa.RecordBatch, _arrow_recordbatch_reduce
deserializer=lambda args: pajson.ParseOptions(*args),
)
# Register custom reducer for Arrow Tables.
serialization_context._register_cloudpickle_reducer(pa.Table, _arrow_table_reduce)


def _get_arrow_array_types() -> List[type]:
"""Get all Arrow array types that we want to register a custom serializer for."""
import pyarrow as pa
from ray.data.extensions import ArrowTensorArray, ArrowVariableShapedTensorArray

array_types = [
pa.lib.NullArray,
pa.lib.BooleanArray,
pa.lib.UInt8Array,
pa.lib.UInt16Array,
pa.lib.UInt32Array,
pa.lib.UInt64Array,
pa.lib.Int8Array,
pa.lib.Int16Array,
pa.lib.Int32Array,
pa.lib.Int64Array,
pa.lib.Date32Array,
pa.lib.Date64Array,
pa.lib.TimestampArray,
pa.lib.Time32Array,
pa.lib.Time64Array,
pa.lib.DurationArray,
pa.lib.HalfFloatArray,
pa.lib.FloatArray,
pa.lib.DoubleArray,
pa.lib.ListArray,
pa.lib.LargeListArray,
pa.lib.MapArray,
pa.lib.FixedSizeListArray,
pa.lib.UnionArray,
pa.lib.BinaryArray,
pa.lib.StringArray,
pa.lib.LargeBinaryArray,
pa.lib.LargeStringArray,
pa.lib.DictionaryArray,
pa.lib.FixedSizeBinaryArray,
pa.lib.Decimal128Array,
pa.lib.Decimal256Array,
pa.lib.StructArray,
pa.lib.ExtensionArray,
ArrowTensorArray,
ArrowVariableShapedTensorArray,
]
try:
array_types.append(pa.lib.MonthDayNanoIntervalArray)
except AttributeError:
# MonthDayNanoIntervalArray doesn't exist on older pyarrow versions.
pass
return array_types


def _arrow_array_reduce(a: "pyarrow.Array"):
"""Custom reducer for Arrow arrays that works around a zero-copy slicing pickling
bug by using the Arrow IPC format for the underlying serialization.
"""
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([a], [""])
restore_recordbatch, serialized = _arrow_recordbatch_reduce(batch)

return functools.partial(_restore_array, restore_recordbatch), serialized


def _restore_array(
restore_recordbatch: Callable[[bytes], "pyarrow.RecordBatch"], buf: bytes
) -> "pyarrow.Array":
"""Restore a serialized Arrow Array."""
return restore_recordbatch(buf).column(0)


def _arrow_chunkedarray_reduce(a: "pyarrow.ChunkedArray"):
"""Custom reducer for Arrow ChunkedArrays that works around a zero-copy slicing
pickling bug by using the Arrow IPC format for the underlying serialization.
"""
import pyarrow as pa

table = pa.Table.from_arrays([a], names=[""])
restore_table, serialized = _arrow_table_reduce(table)
return functools.partial(_restore_chunkedarray, restore_table), serialized


def _restore_chunkedarray(
restore_table: Callable[[bytes], "pyarrow.Table"], buf: bytes
) -> "pyarrow.ChunkedArray":
"""Restore a serialized Arrow ChunkedArray."""
return restore_table(buf).column(0)


def _arrow_recordbatch_reduce(batch: "pyarrow.RecordBatch"):
"""Custom reducer for Arrow RecordBatch that works around a zero-copy slicing
pickling bug by using the Arrow IPC format for the underlying serialization.
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import BufferOutputStream

output_stream = BufferOutputStream()
with RecordBatchStreamWriter(output_stream, schema=batch.schema) as wr:
wr.write_batch(batch)
return _restore_recordbatch, (output_stream.getvalue(),)


def _restore_recordbatch(buf: bytes) -> "pyarrow.RecordBatch":
"""Restore a serialized Arrow RecordBatch."""
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buf) as reader:
return reader.read_next_batch()


def _arrow_table_reduce(table: "pyarrow.Table"):
"""Custom reducer for Arrow Table that works around a zero-copy slicing pickling
bug by using the Arrow IPC format for the underlying serialization.
"""
from pyarrow.ipc import RecordBatchStreamWriter
from pyarrow.lib import BufferOutputStream

output_stream = BufferOutputStream()
with RecordBatchStreamWriter(output_stream, schema=table.schema) as wr:
wr.write_table(table)
return _restore_table, (output_stream.getvalue(),)


def _restore_table(buf: bytes) -> "pyarrow.Table":
"""Restore a serialized Arrow Table."""
from pyarrow.ipc import RecordBatchStreamReader

with RecordBatchStreamReader(buf) as reader:
return reader.read_all()
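
The hunk above removes the IPC-based reducers. For reference, the core trick they relied on (visible in _arrow_table_reduce/_restore_table above) is that writing a sliced table through the Arrow IPC stream format emits only the sliced region of each buffer, so the serialized payload stays proportional to the slice. A minimal round-trip sketch of that behavior, not the Ray API itself:

```python
import pyarrow as pa
from pyarrow.ipc import RecordBatchStreamReader, RecordBatchStreamWriter
from pyarrow.lib import BufferOutputStream

table = pa.table({"x": list(range(1_000_000))})
view = table.slice(0, 10)  # zero-copy slice that still references the full column buffer

# Serialize: IPC truncates each buffer to the sliced region.
sink = BufferOutputStream()
with RecordBatchStreamWriter(sink, schema=view.schema) as writer:
    writer.write_table(view)
buf = sink.getvalue()

# Deserialize.
with RecordBatchStreamReader(buf) as reader:
    restored = reader.read_all()

print(buf.size)           # on the order of the 10-row slice, not the 1M-row parent
print(restored.num_rows)  # 10
```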
1 change: 0 additions & 1 deletion python/ray/data/_internal/delegating_block_builder.py
@@ -59,7 +59,6 @@ def build(self) -> Block:
if self._builder is None:
if self._empty_block is not None:
self._builder = BlockAccessor.for_block(self._empty_block).builder()
self._builder.add_block(self._empty_block)
else:
self._builder = ArrowBlockBuilder()
return self._builder.build()
2 changes: 1 addition & 1 deletion python/ray/data/_internal/pandas_block.py
@@ -141,7 +141,7 @@ def _build_tensor_row(row: PandasRow) -> np.ndarray:
tensor = tensor.to_numpy()
return tensor

def slice(self, start: int, end: int, copy: bool = False) -> "pandas.DataFrame":
def slice(self, start: int, end: int, copy: bool) -> "pandas.DataFrame":
view = self._table[start:end]
view.reset_index(drop=True, inplace=True)
if copy:
