diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index 10b26fa0..e61d8533 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -24,7 +24,7 @@ try: from ray.data.dataset import Dataset as RayDataset -except (ImportError, ModuleNotFoundError): +except ImportError: class RayDataset: pass @@ -916,6 +916,8 @@ def _can_load_distributed(source: Data) -> bool: return False elif Modin.is_data_type(source): return True + elif isinstance(source, RayDataset): + return True elif isinstance(source, str): # Strings should point to files or URLs # Usually parquet files point to directories @@ -940,6 +942,8 @@ def _detect_distributed(source: Data) -> bool: return False if Modin.is_data_type(source): return True + if isinstance(source, RayDataset): + return True if isinstance(source, Iterable) and not isinstance(source, str) and \ not (isinstance(source, Sequence) and isinstance(source[0], str)): # This is an iterable but not a Sequence of strings, and not a diff --git a/xgboost_ray/tests/test_matrix.py b/xgboost_ray/tests/test_matrix.py index e67e3fbe..f312cac0 100644 --- a/xgboost_ray/tests/test_matrix.py +++ b/xgboost_ray/tests/test_matrix.py @@ -8,6 +8,11 @@ import pandas as pd import ray +try: + import ray.data as ray_data +except (ImportError, ModuleNotFoundError): + + ray_data = None from xgboost_ray import RayDMatrix from xgboost_ray.matrix import (concat_dataframes, RayShardingMode, @@ -29,7 +34,7 @@ def setUp(self): @classmethod def setUpClass(cls): - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls): @@ -315,6 +320,11 @@ def testDetectDistributed(self): mat = RayDMatrix([csv_file] * 3, lazy=True) self.assertTrue(mat.distributed) + if ray_data: + ds = ray_data.read_parquet(parquet_file) + mat = RayDMatrix(ds) + self.assertTrue(mat.distributed) + def testTooManyActorsDistributed(self): """Test error when too many actors are passed""" with self.assertRaises(RuntimeError):