Skip to content

Commit

Permalink
Change offsets dtype to int64 and change to LargeList for Arrow(Varia…
Browse files Browse the repository at this point in the history
…bleShaped)TensorArray

Signed-off-by: Peter Wang <[email protected]>
  • Loading branch information
Peter Wang committed Jun 4, 2024
1 parent 32cddae commit 1e00179
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
8 changes: 8 additions & 0 deletions python/ray/air/tests/test_tensor_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,14 @@ def test_variable_shaped_tensor_array_uniform_dim():
np.testing.assert_array_equal(a, expected)


def test_large_arrow_tensor_array():
    """Build an ArrowTensorArray from many identical tensors and check each one.

    4000 tensors of shape (1000, 550) with uint8 elements add up to
    2.2e9 bytes, which is more than a signed 32-bit offset (max 2**31 - 1)
    can address. Presumably this guards the switch to 64-bit offsets /
    large-list storage made in the same change -- confirm against arrow.py.
    """
    tensor_shape = (1000, 550)
    num_tensors = 4000

    single_tensor = np.ones(tensor_shape, dtype=np.uint8)
    # The same ndarray object is repeated; only the resulting Arrow array
    # holds num_tensors separate copies of the data.
    tensor_array = ArrowTensorArray.from_numpy([single_tensor] * num_tensors)

    assert len(tensor_array) == num_tensors
    for element in tensor_array:
        assert element.as_py().shape == tensor_shape


if __name__ == "__main__":
import sys

Expand Down
14 changes: 8 additions & 6 deletions python/ray/air/util/tensor_extensions/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, shape: Tuple[int, ...], dtype: pa.DataType):
dtype: pyarrow dtype of tensor elements.
"""
self._shape = shape
super().__init__(pa.list_(dtype), "ray.data.arrow_tensor")
super().__init__(pa.large_list(dtype), "ray.data.arrow_tensor")

@property
def shape(self):
Expand Down Expand Up @@ -316,7 +316,7 @@ class ArrowTensorArray(_ArrowTensorScalarIndexingMixin, pa.ExtensionArray):
https://arrow.apache.org/docs/python/extending_types.html#custom-extension-array-class
"""

OFFSET_DTYPE = np.int32
OFFSET_DTYPE = np.int64

@classmethod
def from_numpy(
Expand Down Expand Up @@ -414,7 +414,7 @@ def _from_numpy(
)

storage = pa.Array.from_buffers(
pa.list_(pa_dtype),
pa.large_list(pa_dtype),
outer_len,
[None, offset_buffer],
children=[data_array],
Expand Down Expand Up @@ -612,7 +612,9 @@ def __init__(self, dtype: pa.DataType, ndim: int):
"""
self._ndim = ndim
super().__init__(
pa.struct([("data", pa.list_(dtype)), ("shape", pa.list_(pa.int64()))]),
pa.struct(
[("data", pa.large_list(dtype)), ("shape", pa.list_(pa.int64()))]
),
"ray.data.arrow_variable_shaped_tensor",
)

Expand Down Expand Up @@ -719,7 +721,7 @@ class ArrowVariableShapedTensorArray(
https://arrow.apache.org/docs/python/extending_types.html#custom-extension-array-class
"""

OFFSET_DTYPE = np.int32
OFFSET_DTYPE = np.int64

@classmethod
def from_numpy(
Expand Down Expand Up @@ -809,7 +811,7 @@ def from_numpy(
# corresponds to a tensor element.
size_offsets = np.insert(size_offsets, 0, 0)
offset_array = pa.array(size_offsets)
data_array = pa.ListArray.from_arrays(offset_array, value_array)
data_array = pa.LargeListArray.from_arrays(offset_array, value_array)
# We store the tensor element shapes so we can reconstruct each tensor when
# converting back to NumPy ndarrays.
shape_array = pa.array(shapes)
Expand Down

0 comments on commit 1e00179

Please sign in to comment.