From d82ee25054b90ce5416f98642d86bade4fe7a1fe Mon Sep 17 00:00:00 2001 From: Clark Zinzow Date: Wed, 8 Jun 2022 20:17:58 +0000 Subject: [PATCH] Fix stats construction for from_* APIs. --- python/ray/data/read_api.py | 23 +++++++++++-------- python/ray/data/tests/test_dataset_formats.py | 20 ++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index c03ba94f55fc..fe3e282392f7 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -821,9 +821,12 @@ def from_pandas_refs( context = DatasetContext.get_current() if context.enable_pandas_block: get_metadata = cached_remote_fn(_get_metadata) - metadata = [get_metadata.remote(df) for df in dfs] + metadata = ray.get([get_metadata.remote(df) for df in dfs]) return Dataset( - ExecutionPlan(BlockList(dfs, ray.get(metadata)), DatasetStats.TODO()), + ExecutionPlan( + BlockList(dfs, metadata), + DatasetStats(stages={"from_pandas_refs": metadata}, parent=None), + ), 0, False, ) @@ -831,10 +834,11 @@ def from_pandas_refs( df_to_block = cached_remote_fn(_df_to_block, num_returns=2) res = [df_to_block.remote(df) for df in dfs] - blocks, metadata = zip(*res) + blocks, metadata = map(list, zip(*res)) + metadata = ray.get(metadata) return Dataset( ExecutionPlan( - BlockList(blocks, ray.get(list(metadata))), + BlockList(blocks, metadata), DatasetStats(stages={"from_pandas_refs": metadata}, parent=None), ), 0, @@ -888,11 +892,12 @@ def from_numpy_refs( ndarray_to_block = cached_remote_fn(_ndarray_to_block, num_returns=2) res = [ndarray_to_block.remote(ndarray) for ndarray in ndarrays] - blocks, metadata = zip(*res) + blocks, metadata = map(list, zip(*res)) + metadata = ray.get(metadata) return Dataset( ExecutionPlan( - BlockList(blocks, ray.get(list(metadata))), - DatasetStats(stages={"from_numpy": metadata}, parent=None), + BlockList(blocks, metadata), + DatasetStats(stages={"from_numpy_refs": metadata}, parent=None), ), 0, False, @@ -939,10 +944,10 @@ def from_arrow_refs( tables = [tables] get_metadata = cached_remote_fn(_get_metadata) - metadata = [get_metadata.remote(t) for t in tables] + metadata = ray.get([get_metadata.remote(t) for t in tables]) return Dataset( ExecutionPlan( - BlockList(tables, ray.get(metadata)), + BlockList(tables, metadata), DatasetStats(stages={"from_arrow_refs": metadata}, parent=None), ), 0, diff --git a/python/ray/data/tests/test_dataset_formats.py b/python/ray/data/tests/test_dataset_formats.py index 09a06133ff1e..4863085493e9 100644 --- a/python/ray/data/tests/test_dataset_formats.py +++ b/python/ray/data/tests/test_dataset_formats.py @@ -77,6 +77,8 @@ def test_from_pandas(ray_start_regular_shared, enable_pandas_block): values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_pandas_refs" in ds.stats() # test from single pandas dataframe ds = ray.data.from_pandas(df1) @@ -84,6 +86,8 @@ def test_from_pandas(ray_start_regular_shared, enable_pandas_block): values = [(r["one"], r["two"]) for r in ds.take(3)] rows = [(r.one, r.two) for _, r in df1.iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_pandas_refs" in ds.stats() finally: ctx.enable_pandas_block = old_enable_pandas_block @@ -101,6 +105,8 @@ def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block): values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_pandas_refs" in ds.stats() # test from single pandas dataframe ref ds = ray.data.from_pandas_refs(ray.put(df1)) @@ -108,6 +114,8 @@ def test_from_pandas_refs(ray_start_regular_shared, enable_pandas_block): values = [(r["one"], r["two"]) for r in ds.take(3)] rows = [(r.one, r.two) for _, r in df1.iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_pandas_refs" in ds.stats() finally: ctx.enable_pandas_block = old_enable_pandas_block @@ -123,6 +131,8 @@ def test_from_numpy(ray_start_regular_shared, from_ref): ds = ray.data.from_numpy(arrs) values = np.stack(ds.take(8)) np.testing.assert_array_equal(values, np.concatenate((arr1, arr2))) + # Check that conversion task is included in stats. + assert "from_numpy_refs" in ds.stats() # Test from single NumPy ndarray. if from_ref: @@ -131,6 +141,8 @@ def test_from_numpy(ray_start_regular_shared, from_ref): ds = ray.data.from_numpy(arr1) values = np.stack(ds.take(4)) np.testing.assert_array_equal(values, arr1) + # Check that conversion task is included in stats. + assert "from_numpy_refs" in ds.stats() def test_from_arrow(ray_start_regular_shared): @@ -140,12 +152,16 @@ def test_from_arrow(ray_start_regular_shared): values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_arrow_refs" in ds.stats() # test from single pyarrow table ds = ray.data.from_arrow(pa.Table.from_pandas(df1)) values = [(r["one"], r["two"]) for r in ds.take(3)] rows = [(r.one, r.two) for _, r in df1.iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_arrow_refs" in ds.stats() def test_from_arrow_refs(ray_start_regular_shared): @@ -157,12 +173,16 @@ def test_from_arrow_refs(ray_start_regular_shared): values = [(r["one"], r["two"]) for r in ds.take(6)] rows = [(r.one, r.two) for _, r in pd.concat([df1, df2]).iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_arrow_refs" in ds.stats() # test from single pyarrow table ref ds = ray.data.from_arrow_refs(ray.put(pa.Table.from_pandas(df1))) values = [(r["one"], r["two"]) for r in ds.take(3)] rows = [(r.one, r.two) for _, r in df1.iterrows()] assert values == rows + # Check that metadata fetch is included in stats. + assert "from_arrow_refs" in ds.stats() def test_to_pandas(ray_start_regular_shared):