Skip to content

Commit

Permalink
[Datasets] [Hotfix] Fix stats construction for from_* APIs. (#25601)
Browse files Browse the repository at this point in the history
Stats construction for the from_arrow and from_numpy APIs (and for from_pandas with Pandas block support disabled) is currently broken because we aren't resolving the block metadata before passing it to the stats, which causes subsequent ds.stats() calls to fail. This PR fixes that and adds some test coverage.

Drive-by changes:

- Adds stats for from_pandas() zero-copy path (metadata fetch only).
- Changes "from_numpy" stats stage name to "from_numpy_refs", to be consistent with stats for other from_*() APIs.
  • Loading branch information
clarkzinzow authored Jun 9, 2022
1 parent f3c2bd6 commit 6987ab5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
23 changes: 14 additions & 9 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,20 +821,24 @@ def from_pandas_refs(
context = DatasetContext.get_current()
if context.enable_pandas_block:
get_metadata = cached_remote_fn(_get_metadata)
metadata = [get_metadata.remote(df) for df in dfs]
metadata = ray.get([get_metadata.remote(df) for df in dfs])
return Dataset(
ExecutionPlan(BlockList(dfs, ray.get(metadata)), DatasetStats.TODO()),
ExecutionPlan(
BlockList(dfs, metadata),
DatasetStats(stages={"from_pandas_refs": metadata}, parent=None),
),
0,
False,
)

df_to_block = cached_remote_fn(_df_to_block, num_returns=2)

res = [df_to_block.remote(df) for df in dfs]
blocks, metadata = zip(*res)
blocks, metadata = map(list, zip(*res))
metadata = ray.get(metadata)
return Dataset(
ExecutionPlan(
BlockList(blocks, ray.get(list(metadata))),
BlockList(blocks, metadata),
DatasetStats(stages={"from_pandas_refs": metadata}, parent=None),
),
0,
Expand Down Expand Up @@ -888,11 +892,12 @@ def from_numpy_refs(
ndarray_to_block = cached_remote_fn(_ndarray_to_block, num_returns=2)

res = [ndarray_to_block.remote(ndarray) for ndarray in ndarrays]
blocks, metadata = zip(*res)
blocks, metadata = map(list, zip(*res))
metadata = ray.get(metadata)
return Dataset(
ExecutionPlan(
BlockList(blocks, ray.get(list(metadata))),
DatasetStats(stages={"from_numpy": metadata}, parent=None),
BlockList(blocks, metadata),
DatasetStats(stages={"from_numpy_refs": metadata}, parent=None),
),
0,
False,
Expand Down Expand Up @@ -939,10 +944,10 @@ def from_arrow_refs(
tables = [tables]

get_metadata = cached_remote_fn(_get_metadata)
metadata = [get_metadata.remote(t) for t in tables]
metadata = ray.get([get_metadata.remote(t) for t in tables])
return Dataset(
ExecutionPlan(
BlockList(tables, ray.get(metadata)),
BlockList(tables, metadata),
DatasetStats(stages={"from_arrow_refs": metadata}, parent=None),
),
0,
Expand Down
20 changes: 20 additions & 0 deletions python/ray/data/tests/test_dataset_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,17 @@ def test_from_pandas(ray_start_regular_shared, enable_pandas_block):
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_pandas_refs" in ds.stats()

# test from single pandas dataframe
ds = ray.data.from_pandas(df1)
assert ds._dataset_format() == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_pandas_refs" in ds.stats()
finally:
ctx.enable_pandas_block = old_enable_pandas_block

Expand All @@ -101,13 +105,17 @@ def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block):
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_pandas_refs" in ds.stats()

# test from single pandas dataframe ref
ds = ray.data.from_pandas_refs(ray.put(df1))
assert ds._dataset_format() == "pandas" if enable_pandas_block else "arrow"
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_pandas_refs" in ds.stats()
finally:
ctx.enable_pandas_block = old_enable_pandas_block

Expand All @@ -123,6 +131,8 @@ def test_from_numpy(ray_start_regular_shared, from_ref):
ds = ray.data.from_numpy(arrs)
values = np.stack(ds.take(8))
np.testing.assert_array_equal(values, np.concatenate((arr1, arr2)))
# Check that conversion task is included in stats.
assert "from_numpy_refs" in ds.stats()

# Test from single NumPy ndarray.
if from_ref:
Expand All @@ -131,6 +141,8 @@ def test_from_numpy(ray_start_regular_shared, from_ref):
ds = ray.data.from_numpy(arr1)
values = np.stack(ds.take(4))
np.testing.assert_array_equal(values, arr1)
# Check that conversion task is included in stats.
assert "from_numpy_refs" in ds.stats()


def test_from_arrow(ray_start_regular_shared):
Expand All @@ -140,12 +152,16 @@ def test_from_arrow(ray_start_regular_shared):
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_arrow_refs" in ds.stats()

# test from single pyarrow table
ds = ray.data.from_arrow(pa.Table.from_pandas(df1))
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_arrow_refs" in ds.stats()


def test_from_arrow_refs(ray_start_regular_shared):
Expand All @@ -157,12 +173,16 @@ def test_from_arrow_refs(ray_start_regular_shared):
values = [(r["one"], r["two"]) for r in ds.take(6)]
rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_arrow_refs" in ds.stats()

# test from single pyarrow table ref
ds = ray.data.from_arrow_refs(ray.put(pa.Table.from_pandas(df1)))
values = [(r["one"], r["two"]) for r in ds.take(3)]
rows = [(r.one, r.two) for _, r in df1.iterrows()]
assert values == rows
# Check that metadata fetch is included in stats.
assert "from_arrow_refs" in ds.stats()


def test_to_pandas(ray_start_regular_shared):
Expand Down

0 comments on commit 6987ab5

Please sign in to comment.