[Datasets] Fix byte size calculation for non-trivial tensors (#25264)
The range datasource was incorrectly calculating tensor block sizes whenever the tensor shape was anything other than (1,).

Broken out from https://github.com/ray-project/ray/pull/25167/files
ericl authored May 31, 2022
1 parent 65f908e commit c93e37a
Showing 2 changed files with 6 additions and 1 deletion.
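
The corrected arithmetic is easy to sketch in isolation. A minimal example with illustrative values (count and tensor_shape below are made up for the example, not taken from the datasource):

import numpy as np

count = 5              # rows in a block
tensor_shape = (3, 5)  # per-row tensor shape (example value)

# Before the fix: block size assumed a single 8-byte element per row.
old_size_bytes = 8 * count                 # 40

# After the fix: each row contributes prod(tensor_shape) 8-byte int64 elements.
# (The patch calls np.product, which is an alias of np.prod.)
element_size = np.prod(tensor_shape)       # 15
new_size_bytes = 8 * count * element_size  # 600
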
python/ray/data/datasource/datasource.py (5 additions & 1 deletion)
@@ -228,9 +228,13 @@ def make_block(start: int, count: int) -> Block:
             schema = int
         else:
             raise ValueError("Unsupported block type", block_format)
+        if block_format == "tensor":
+            element_size = np.product(tensor_shape)
+        else:
+            element_size = 1
         meta = BlockMetadata(
             num_rows=count,
-            size_bytes=8 * count,
+            size_bytes=8 * count * element_size,
             schema=schema,
             input_files=None,
             exec_stats=None,
python/ray/data/tests/test_dataset.py (1 addition & 0 deletions)
@@ -450,6 +450,7 @@ def test_tensors(ray_start_regular_shared):
"Dataset(num_blocks=5, num_rows=5, "
"schema={value: <ArrowTensorType: shape=(3, 5), dtype=int64>})"
)
assert ds.size_bytes() == 5 * 3 * 5 * 8

# Pandas conversion.
res = (
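
For context, a sketch of how the new assertion in test_tensors could be reproduced end to end; the dataset construction is not shown in this diff, so the range_tensor call below is an assumption inferred from the asserted schema string (5 rows of shape=(3, 5) int64 tensors):

import ray

# Assumed construction matching the schema asserted above:
# 5 rows, each an int64 tensor of shape (3, 5).
ds = ray.data.range_tensor(5, shape=(3, 5))

# 5 rows * (3 * 5) elements per row * 8 bytes per int64 = 600 bytes.
assert ds.size_bytes() == 5 * 3 * 5 * 8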
