[data] Fix ragged tensor conversion with map() (ray-project#35419)
From ray-project#35143, we found a map() case that was not covered by our numpy support test cases.

Signed-off-by: e428265 <[email protected]>
ericl authored and arvind-chandra committed Aug 31, 2023
1 parent c15c35e commit 1447eb7
Showing 7 changed files with 64 additions and 24 deletions.
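
For context, a minimal sketch of the kind of map() case this commit addresses; the UDF and shapes below are illustrative assumptions based on the tests in this diff, not code from the commit:

    import numpy as np
    import ray

    # A per-row UDF whose output shape depends on the row produces a "ragged"
    # column that cannot be stacked into a single fixed-shape ndarray.
    ds = ray.data.range(2, parallelism=1)
    ds = ds.map(lambda row: {"output": np.zeros((3, 5 + 3 * row["id"], 10))})
    batch = ds.take_batch()
    print(batch["output"].dtype)  # object: 1D ndarray holding variable-shaped arrays
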
python/ray/air/util/tensor_extensions/arrow.py (2 changes: 1 addition & 1 deletion)

@@ -691,7 +691,7 @@ def from_numpy(
                 raise ValueError(
                     "ArrowVariableShapedTensorArray only supports tensor elements that "
                     "all have the same number of dimensions, but got tensor elements "
-                    f"with dimensions: {ndim}, {a.ndim}"
+                    f"with dimensions: {ndim}, {a.ndim}: {arr}"
                 )
             ndim = a.ndim
             shapes.append(a.shape)
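
To see the changed error path in isolation, a sketch (my construction, not from the commit) that feeds elements with mismatched numbers of dimensions to the classmethod patched above:

    import numpy as np
    from ray.air.util.tensor_extensions.arrow import ArrowVariableShapedTensorArray

    # Build an object array whose elements have different ndims (2D vs. 1D).
    ragged = np.empty(2, dtype=object)
    ragged[0], ragged[1] = np.zeros((2, 2)), np.zeros(3)
    try:
        ArrowVariableShapedTensorArray.from_numpy(ragged)
    except ValueError as e:
        print(e)  # message now ends with the offending input, per the diff
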
python/ray/data/_internal/numpy_support.py (23 changes: 18 additions & 5 deletions)

@@ -25,6 +25,14 @@ def is_valid_udf_return(udf_return_col: Any) -> bool:
     return isinstance(udf_return_col, list) or is_array_like(udf_return_col)
 
 
+def is_scalar_list(udf_return_col: Any) -> bool:
+    """Check whether a UDF column is a scalar list."""
+
+    return isinstance(udf_return_col, list) and (
+        not udf_return_col or np.isscalar(udf_return_col[0])
+    )
+
+
 def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
     """Convert UDF columns (output of map_batches) to numpy, if possible.
 
@@ -55,16 +63,21 @@ def convert_udf_returns_to_numpy(udf_return_col: Any) -> Any:
         # `str` are also Iterable.
         try:
             # Try to cast the inner scalars to numpy as well, to avoid unnecessarily
-            # creating an inefficient array of array of object dtype.
-            if all(is_valid_udf_return(e) for e in udf_return_col):
+            # creating an inefficient array of array of object dtype. Don't convert
+            # scalar lists though, since those can be represented as pyarrow list type
+            # without needing to go through our tensor extension.
+            if all(
+                is_valid_udf_return(e) and not is_scalar_list(e) for e in udf_return_col
+            ):
                 udf_return_col = [np.array(e) for e in udf_return_col]
             shapes = set()
+            has_object = False
             for e in udf_return_col:
                 if isinstance(e, np.ndarray):
                     shapes.add((e.dtype, e.shape))
-                else:
-                    shapes.add(type(e))
-            if len(shapes) > 1:
+                elif not np.isscalar(e):
+                    has_object = True
+            if has_object or len(shapes) > 1:
                 # This util works around some limitations of np.array(dtype=object).
                 udf_return_col = create_ragged_ndarray(udf_return_col)
             else:
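
A standalone sketch of the new gating logic; is_scalar_list here mirrors the helper added above, and the inputs are taken from the tests later in this diff:

    import numpy as np

    def is_scalar_list(col):
        # An empty list, or a list whose first element is a scalar, can be
        # represented as a pyarrow list type, so it skips np.array() casting.
        return isinstance(col, list) and (not col or np.isscalar(col[0]))

    print(is_scalar_list([1, 2, 3]))   # True  -> kept as a list column
    print(is_scalar_list([]))          # True  -> empty list counts as scalar list
    print(is_scalar_list([[1], [2]]))  # False -> still cast element-wise to numpy
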
python/ray/data/_internal/pandas_block.py (2 changes: 0 additions & 2 deletions)

@@ -97,8 +97,6 @@ def _table_from_pydict(columns: Dict[str, List[Any]]) -> "pandas.DataFrame":
             ):
                 from ray.data.extensions.tensor_extension import TensorArray
 
-                if len(value) == 1:
-                    value = value[0]
                 columns[key] = TensorArray(value)
         return pandas.DataFrame(columns)
 
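
The deleted unwrap mattered for single-row tensor columns; a sketch of the intended behavior (the shape is my example; TensorArray is Ray's pandas extension array):

    import numpy as np
    from ray.data.extensions.tensor_extension import TensorArray

    col = [np.zeros((3, 5))]      # one row whose value is a 3x5 tensor
    # One row of shape (3, 5); the old unwrap passed np.zeros((3, 5)) directly,
    # which TensorArray reads as 3 rows of shape (5,).
    print(len(TensorArray(col)))  # 1
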
python/ray/data/_internal/table_block.py (9 changes: 6 additions & 3 deletions)

@@ -8,7 +8,7 @@
 from ray.data.block import Block, BlockAccessor
 from ray.data.row import TableRow
 from ray.data._internal.block_builder import BlockBuilder
-from ray.data._internal.numpy_support import is_array_like
+from ray.data._internal.numpy_support import is_array_like, convert_udf_returns_to_numpy
 from ray.data._internal.size_estimator import SizeEstimator
 from ray.data._internal.util import _is_tensor_schema
 
@@ -118,8 +118,11 @@ def will_build_yield_copy(self) -> bool:
         return self._concat_would_copy() and len(self._tables) > 1
 
     def build(self) -> Block:
-        if self._columns:
-            tables = [self._table_from_pydict(self._columns)]
+        columns = {
+            key: convert_udf_returns_to_numpy(col) for key, col in self._columns.items()
+        }
+        if columns:
+            tables = [self._table_from_pydict(columns)]
         else:
             tables = []
         tables.extend(self._tables)
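
In effect, build() now funnels every accumulated column through convert_udf_returns_to_numpy() before a table is materialized. A sketch using the same internal helper as the diff (the column contents are illustrative):

    import numpy as np
    from ray.data._internal.numpy_support import convert_udf_returns_to_numpy

    columns = {"output": [np.zeros((3, 5, 10)), np.zeros((3, 8, 8))]}
    # Ragged per-row outputs collapse into one object-dtype ndarray that the
    # pandas/Arrow block builders can wrap in the tensor extension type.
    columns = {k: convert_udf_returns_to_numpy(v) for k, v in columns.items()}
    print(columns["output"].shape, columns["output"].dtype)  # (2,) object
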
python/ray/data/tests/test_all_to_all.py (6 changes: 3 additions & 3 deletions)

@@ -697,9 +697,9 @@ def _to_arrow(ds):
     nan_agg_ds = nan_grouped_ds.std("B", ignore_nulls=False)
     assert nan_agg_ds.count() == 3
     result = nan_agg_ds.to_pandas()["std(B)"].to_numpy()
-    expected = nan_df.groupby("A")["B"].std()
-    expected[0] = None
-    np.testing.assert_array_almost_equal(result, expected)
+    expected = nan_df.groupby("A")["B"].std().to_numpy()
+    assert result[0] is None or np.isnan(result[0])
+    np.testing.assert_array_almost_equal(result[1:], expected[1:])
     # Test all nans
     nan_df = pd.DataFrame({"A": [x % 3 for x in xs], "B": [None] * len(xs)})
     ds = ray.data.from_pandas(nan_df).repartition(num_parts)
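
The loosened assertion reflects that the null std for the group containing None can surface either as None (object-typed blocks) or as NaN (float-typed blocks) after the block-building changes above; a tiny illustration of the two accepted forms (my example, not test code):

    import numpy as np

    for null_std in (None, float("nan")):
        # Either representation of a missing aggregate passes the new check.
        assert null_std is None or np.isnan(null_std)
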
python/ray/data/tests/test_numpy_support.py (40 changes: 31 additions & 9 deletions)

@@ -71,42 +71,56 @@ def test_ragged_array_like(ray_start_regular_shared):
         output, np.array([np.array([1, 2, 3]), np.array([1, 2])], dtype=object)
     )
 
-
-def test_ragged_lists(ray_start_regular_shared):
-    data = [[1, 2, 3], [1, 2]]
+    data = [torch.zeros((3, 5, 10)), torch.zeros((3, 8, 8))]
     output = do_map_batches(data)
     assert_structure_equals(
-        output, np.array([np.array([1, 2, 3]), np.array([1, 2])], dtype=object)
+        output, create_ragged_ndarray([np.zeros((3, 5, 10)), np.zeros((3, 8, 8))])
     )
 
 
+def test_scalar_nested_arrays(ray_start_regular_shared):
+    data = [[[1]], [[2]]]
+    output = do_map_batches(data)
+    assert_structure_equals(output, np.array([[[1]], [[2]]]))
+
+
+def test_scalar_lists_not_converted(ray_start_regular_shared):
+    data = [[1, 2], [1, 2]]
+    output = do_map_batches(data)
+    assert_structure_equals(output, create_ragged_ndarray([[1, 2], [1, 2]]))
+
+    data = [[1, 2, 3], [1, 2]]
+    output = do_map_batches(data)
+    assert_structure_equals(output, create_ragged_ndarray([[1, 2, 3], [1, 2]]))
+
+
 def test_scalar_numpy(ray_start_regular_shared):
     data = np.int64(1)
-    ds = ray.data.range(2)
+    ds = ray.data.range(2, parallelism=1)
     ds = ds.map(lambda x: {"output": data})
     output = ds.take_batch()["output"]
     assert_structure_equals(output, np.array([1, 1], dtype=np.int64))
 
 
 def test_scalar_arrays(ray_start_regular_shared):
     data = np.array([1, 2, 3])
-    ds = ray.data.range(2)
+    ds = ray.data.range(2, parallelism=1)
     ds = ds.map(lambda x: {"output": data})
     output = ds.take_batch()["output"]
     assert_structure_equals(output, np.array([[1, 2, 3], [1, 2, 3]], dtype=np.int64))
 
 
 def test_scalar_array_like(ray_start_regular_shared):
     data = torch.Tensor([1, 2, 3])
-    ds = ray.data.range(2)
+    ds = ray.data.range(2, parallelism=1)
     ds = ds.map(lambda x: {"output": data})
     output = ds.take_batch()["output"]
     assert_structure_equals(output, np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32))
 
 
 def test_scalar_ragged_arrays(ray_start_regular_shared):
     data = [np.array([1, 2, 3]), np.array([1, 2])]
-    ds = ray.data.range(2)
+    ds = ray.data.range(2, parallelism=1)
     ds = ds.map(lambda x: {"output": data[x["id"]]})
     output = ds.take_batch()["output"]
     assert_structure_equals(
@@ -116,13 +130,21 @@ def test_scalar_ragged_arrays(ray_start_regular_shared):

 def test_scalar_ragged_array_like(ray_start_regular_shared):
     data = [torch.Tensor([1, 2, 3]), torch.Tensor([1, 2])]
-    ds = ray.data.range(2)
+    ds = ray.data.range(2, parallelism=1)
     ds = ds.map(lambda x: {"output": data[x["id"]]})
     output = ds.take_batch()["output"]
     assert_structure_equals(
         output, np.array([np.array([1, 2, 3]), np.array([1, 2])], dtype=object)
     )
 
+    data = [torch.zeros((3, 5, 10)), torch.zeros((3, 8, 8))]
+    ds = ray.data.range(2, parallelism=1)
+    ds = ds.map(lambda x: {"output": data[x["id"]]})
+    output = ds.take_batch()["output"]
+    assert_structure_equals(
+        output, create_ragged_ndarray([np.zeros((3, 5, 10)), np.zeros((3, 8, 8))])
+    )
+
 
 # https://github.com/ray-project/ray/issues/35340
 def test_complex_ragged_arrays(ray_start_regular_shared):
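
The updated tests compare against create_ragged_ndarray(...). For reference, a sketch of the kind of array it yields, constructed manually with plain numpy (the helper itself is not shown in this diff):

    import numpy as np

    # np.array([...]) would try to broadcast same-ndim arrays of different
    # shapes; pre-allocating an object array sidesteps that and gives a 1D
    # "ragged" ndarray of ndarrays.
    ragged = np.empty(2, dtype=object)
    ragged[0] = np.zeros((3, 5, 10))
    ragged[1] = np.zeros((3, 8, 8))
    print(ragged.shape, ragged.dtype)  # (2,) object
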
python/ray/data/tests/test_strict_mode.py (6 changes: 5 additions & 1 deletion)

@@ -202,7 +202,11 @@ def test_strict_schema(ray_start_regular_shared, enable_strict_mode):
     ds = ray.data.from_items([{"x": 2, "y": object(), "z": [1, 2]}])
     schema = ds.schema()
     assert schema.names == ["x", "y", "z"]
-    assert schema.types == [pa.int64(), object, object]
+    assert schema.types == [
+        pa.int64(),
+        object,
+        object,
+    ]
 
     ds = ray.data.from_numpy(np.ones((100, 10)))
     schema = ds.schema()
