Skip to content

Commit

Permalink
[Datasets] Convert between block and batch correctly for map_groups (#…
Browse files Browse the repository at this point in the history
…30172)

This is to fix issue found in #30102, where user can do ds.groupby("key").map_groups(fn, batch_format="numpy"). We need to correctly convert between block and batch in map_groups to handle it.
  • Loading branch information
c21 authored Nov 14, 2022
1 parent 54c532f commit 82ccb15
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
14 changes: 10 additions & 4 deletions python/ray/data/grouped_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,17 +341,23 @@ def get_key_boundaries(block_accessor: BlockAccessor):
# The batch is the entire block, because we have batch_size=None for
# map_batches() below.
def group_fn(batch):
block_accessor = BlockAccessor.for_block(batch)
block = BlockAccessor.batch_to_block(batch)
block_accessor = BlockAccessor.for_block(block)
if self._key:
boundaries = get_key_boundaries(block_accessor)
else:
boundaries = [block_accessor.num_rows()]
builder = DelegatingBlockBuilder()
start = 0
for end in boundaries:
group = block_accessor.slice(start, end)
applied = fn(group)
builder.add_block(applied)
group_block = block_accessor.slice(start, end)
group_block_accessor = BlockAccessor.for_block(group_block)
# Convert block of each group to batch format here, because the
# block format here can be different from batch format
# (e.g. block is Arrow format, and batch is NumPy format).
group_batch = group_block_accessor.to_batch_format(batch_format)
applied = fn(group_batch)
builder.add_batch(applied)
start = end
rs = builder.build()
return rs
Expand Down
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4200,6 +4200,26 @@ def normalize(at: pa.Table):
assert result.equals(expected)


def test_groupby_map_groups_for_numpy(ray_start_regular_shared):
ds = ray.data.from_items(
[
{"group": 1, "value": 1},
{"group": 1, "value": 2},
{"group": 2, "value": 3},
{"group": 2, "value": 4},
]
)

def func(group):
# Test output type is NumPy format.
return {"group": group["group"] + 1, "value": group["value"] + 1}

ds = ds.groupby("group").map_groups(func, batch_format="numpy")
expected = pa.Table.from_pydict({"group": [2, 2, 3, 3], "value": [2, 3, 4, 5]})
result = pa.Table.from_pandas(ds.to_pandas())
assert result.equals(expected)


def test_groupby_map_groups_with_different_types(ray_start_regular_shared):
ds = ray.data.from_items(
[
Expand Down

0 comments on commit 82ccb15

Please sign in to comment.