Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
clarkzinzow committed Oct 6, 2022
1 parent 260dce9 commit dc64bb1
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 17 deletions.
8 changes: 6 additions & 2 deletions python/ray/air/tests/test_tensor_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,14 @@ def test_scalar_tensor_array_roundtrip():


def test_arrow_variable_shaped_tensor_array_validation():
# Test homogeneous-typed tensor raises ValueError.
# Test homogeneous-shaped tensor raises ValueError.
with pytest.raises(ValueError):
ArrowVariableShapedTensorArray.from_numpy(np.ones((3, 2, 2)))

# Test tensor elements with differing dimensions raises ValueError.
with pytest.raises(ValueError):
ArrowVariableShapedTensorArray.from_numpy([np.ones((2, 2)), np.ones((3, 3, 3))])

# Test arbitrary object raises ValueError.
with pytest.raises(ValueError):
ArrowVariableShapedTensorArray.from_numpy(object())
Expand Down Expand Up @@ -402,7 +406,7 @@ def test_tensor_array_concat(shape1, shape2):
assert ta.dtype.element_shape == shape1[1:]
np.testing.assert_array_equal(ta.to_numpy(), np.concatenate([a1, a2]))
else:
assert ta.dtype.element_shape is None
assert ta.dtype.element_shape == (None,) * (len(shape1) - 1)
for arr, expected in zip(
ta.to_numpy(), np.array([e for a in [a1, a2] for e in a], dtype=object)
):
Expand Down
28 changes: 22 additions & 6 deletions python/ray/air/util/tensor_extensions/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,15 @@ class ArrowVariableShapedTensorType(pa.PyExtensionType):
https://arrow.apache.org/docs/python/extending_types.html#defining-extension-types-user-defined-types
"""

def __init__(self, dtype: pa.DataType):
def __init__(self, dtype: pa.DataType, ndim: int):
"""
Construct the Arrow extension type for array of heterogeneous-shaped tensors.
Args:
dtype: pyarrow dtype of tensor elements.
ndim: The number of dimensions in the tensor elements.
"""
self._ndim = ndim
super().__init__(
pa.struct([("data", pa.list_(dtype)), ("shape", pa.list_(pa.int64()))])
)
Expand All @@ -321,14 +323,19 @@ def to_pandas_dtype(self):
from ray.air.util.tensor_extensions.pandas import TensorDtype

return TensorDtype(
None,
(None,) * self.ndim,
self.storage_type["data"].type.value_type.to_pandas_dtype(),
)

@property
def ndim(self) -> int:
"""Return the number of dimensions in the tensor elements."""
return self._ndim

def __reduce__(self):
return (
ArrowVariableShapedTensorType,
(self.storage_type["data"].type.value_type,),
(self.storage_type["data"].type.value_type, self._ndim),
)

def __arrow_ext_class__(self):
Expand All @@ -343,7 +350,7 @@ def __arrow_ext_class__(self):

def __str__(self) -> str:
dtype = self.storage_type["data"].type.value_type
return f"ArrowVariableShapedTensorType(dtype={dtype})"
return f"ArrowVariableShapedTensorType(dtype={dtype}, ndim={self.ndim})"

def __repr__(self) -> str:
return str(self)
Expand All @@ -357,7 +364,8 @@ class ArrowVariableShapedTensorArray(pa.ExtensionArray):
This is the Arrow side of TensorArray for tensor elements that have differing
shapes. Note that this extension only supports non-ragged tensor elements; i.e.,
when considering each tensor element in isolation, they must have a well-defined
shape.
shape. This extension also only supports tensor elements that all have the same
number of dimensions.
See Arrow docs for customizing extension arrays:
https://arrow.apache.org/docs/python/extending_types.html#custom-extension-array-class
Expand Down Expand Up @@ -443,8 +451,16 @@ def from_numpy(

# Whether all subndarrays are contiguous views of the same ndarray.
shapes, sizes, raveled = [], [], []
ndim = None
for a in arr:
a = np.asarray(a)
if ndim is not None and a.ndim != ndim:
raise ValueError(
"ArrowVariableShapedTensorArray only supports tensor elements that "
"all have the same number of dimensinos, but got tensor elements "
f"with dimensions: {ndim}, {a.ndim}"
)
ndim = a.ndim
shapes.append(a.shape)
sizes.append(a.size)
# Convert to 1D array view; this should be zero-copy in the common case.
Expand Down Expand Up @@ -494,7 +510,7 @@ def from_numpy(
[data_array, shape_array],
["data", "shape"],
)
type_ = ArrowVariableShapedTensorType(pa_dtype)
type_ = ArrowVariableShapedTensorType(pa_dtype, ndim)
return pa.ExtensionArray.from_storage(type_, storage)

def _to_numpy(self, index: Optional[int] = None, zero_copy_only: bool = False):
Expand Down
12 changes: 6 additions & 6 deletions python/ray/air/util/tensor_extensions/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ class TensorDtype(pd.api.extensions.ExtensionDtype):
# https://github.com/CODAIT/text-extensions-for-pandas/issues/166
base = None

def __init__(self, shape: Optional[Tuple[int, ...]], dtype: np.dtype):
def __init__(self, shape: Tuple[Optional[int], ...], dtype: np.dtype):
self._shape = shape
self._dtype = dtype

Expand All @@ -308,8 +308,8 @@ def element_dtype(self):
@property
def element_shape(self):
"""
The shape of the underlying tensor elements. This will be None if the
corresponding TensorArray for this TensorDtype holds variable-shaped tensor
The shape of the underlying tensor elements. This will be a tuple of Nones if
the corresponding TensorArray for this TensorDtype holds variable-shaped tensor
elements.
"""
return self._shape
Expand All @@ -320,7 +320,7 @@ def is_variable_shaped(self):
Whether the corresponding TensorArray for this TensorDtype holds variable-shaped
tensor elements.
"""
return self.shape is None
return all(dim_size is None for dim_size in self.shape)

@property
def name(self) -> str:
Expand Down Expand Up @@ -384,7 +384,7 @@ def construct_from_string(cls, string: str):
)
# Upstream code uses exceptions as part of its normal control flow and
# will pass this method bogus class names.
regex = r"^TensorDtype\(shape=((?:\((?:\d+,?\s?)*\))|(?:None)), dtype=(\w+)\)$"
regex = r"^TensorDtype\(shape=(\((?:(?:\d+|None),?\s?)*\)), dtype=(\w+)\)$"
m = re.search(regex, string)
err_msg = (
f"Cannot construct a '{cls.__name__}' from '{string}'; expected a string "
Expand Down Expand Up @@ -891,7 +891,7 @@ def dtype(self) -> pd.api.extensions.ExtensionDtype:
# A tensor is only considered variable-shaped if it's non-empty, so no
# non-empty check is needed here.
dtype = self._tensor[0].dtype
shape = None
shape = (None,) * self._tensor[0].ndim
else:
dtype = self.numpy_dtype
shape = self.numpy_shape[1:]
Expand Down
11 changes: 10 additions & 1 deletion python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,17 @@ def numpy_to_block(

@staticmethod
def _build_tensor_row(row: ArrowRow) -> np.ndarray:
col = row[VALUE_COL_NAME]
chunk_idx = 0
while True:
if chunk_idx >= col.num_chunks:
# All empty chunks, return empty ndarray.
return np.array([], col.type.storage_type.value_type.to_pandas_dtype())
if len(col.chunk(chunk_idx)) > 0:
break
chunk_idx += 1
# Getting an item in a tensor column automatically does a NumPy conversion.
return row[VALUE_COL_NAME][0]
return col.chunk(chunk_idx)[0]

def slice(self, start: int, end: int, copy: bool = False) -> "pyarrow.Table":
view = self._table.slice(start, end - start)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def sort(table: "pyarrow.Table", key: "SortKeyT", descending: bool) -> "pyarrow.
import pyarrow.compute as pac

indices = pac.sort_indices(table, sort_keys=key)
return table.take(indices)
return take_table(table, indices)


def take_table(
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def test_tensors_inferred_from_map(ray_start_regular_shared):
)
assert str(ds) == (
"Dataset(num_blocks=4, num_rows=16, "
"schema={a: TensorDtype(shape=None, dtype=float64)})"
"schema={a: TensorDtype(shape=(None, None), dtype=float64)})"
)


Expand Down

0 comments on commit dc64bb1

Please sign in to comment.