From 2ff8fcc7c297c281ab323adafabf01cdce519e28 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 13 Dec 2022 17:20:15 -0800 Subject: [PATCH] Always detect Ray Dataset as distributed (#253) Ensures that we always use distributed loading by default with Ray Datasets. Followup to ray-project/ray#31079 Signed-off-by: Antoni Baum --- xgboost_ray/matrix.py | 6 +++++- xgboost_ray/tests/test_matrix.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index 10b26fa0..e61d8533 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -24,7 +24,7 @@ try: from ray.data.dataset import Dataset as RayDataset -except (ImportError, ModuleNotFoundError): +except ImportError: class RayDataset: pass @@ -916,6 +916,8 @@ def _can_load_distributed(source: Data) -> bool: return False elif Modin.is_data_type(source): return True + elif isinstance(source, RayDataset): + return True elif isinstance(source, str): # Strings should point to files or URLs # Usually parquet files point to directories @@ -940,6 +942,8 @@ def _detect_distributed(source: Data) -> bool: return False if Modin.is_data_type(source): return True + if isinstance(source, RayDataset): + return True if isinstance(source, Iterable) and not isinstance(source, str) and \ not (isinstance(source, Sequence) and isinstance(source[0], str)): # This is an iterable but not a Sequence of strings, and not a diff --git a/xgboost_ray/tests/test_matrix.py b/xgboost_ray/tests/test_matrix.py index e67e3fbe..f312cac0 100644 --- a/xgboost_ray/tests/test_matrix.py +++ b/xgboost_ray/tests/test_matrix.py @@ -8,6 +8,11 @@ import pandas as pd import ray +try: + import ray.data as ray_data +except (ImportError, ModuleNotFoundError): + + ray_data = None from xgboost_ray import RayDMatrix from xgboost_ray.matrix import (concat_dataframes, RayShardingMode, @@ -29,7 +34,7 @@ def setUp(self): @classmethod def setUpClass(cls): - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls): @@ -315,6 +320,11 @@ def testDetectDistributed(self): mat = RayDMatrix([csv_file] * 3, lazy=True) self.assertTrue(mat.distributed) + if ray_data: + ds = ray_data.read_parquet(parquet_file) + mat = RayDMatrix(ds) + self.assertTrue(mat.distributed) + def testTooManyActorsDistributed(self): """Test error when too many actors are passed""" with self.assertRaises(RuntimeError):