Skip to content

Commit

Permalink
Always detect Ray Dataset as distributed (#253)
Browse files Browse the repository at this point in the history
Ensures that we always use distributed loading by default with Ray Datasets. Followup to ray-project/ray#31079

Signed-off-by: Antoni Baum <[email protected]>
  • Loading branch information
Yard1 authored Dec 14, 2022
1 parent 5f016ff commit 2ff8fcc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
6 changes: 5 additions & 1 deletion xgboost_ray/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

try:
from ray.data.dataset import Dataset as RayDataset
except (ImportError, ModuleNotFoundError):
except ImportError:

class RayDataset:
pass
Expand Down Expand Up @@ -916,6 +916,8 @@ def _can_load_distributed(source: Data) -> bool:
return False
elif Modin.is_data_type(source):
return True
elif isinstance(source, RayDataset):
return True
elif isinstance(source, str):
# Strings should point to files or URLs
# Usually parquet files point to directories
Expand All @@ -940,6 +942,8 @@ def _detect_distributed(source: Data) -> bool:
return False
if Modin.is_data_type(source):
return True
if isinstance(source, RayDataset):
return True
if isinstance(source, Iterable) and not isinstance(source, str) and \
not (isinstance(source, Sequence) and isinstance(source[0], str)):
# This is an iterable but not a Sequence of strings, and not a
Expand Down
12 changes: 11 additions & 1 deletion xgboost_ray/tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import pandas as pd

import ray
try:
import ray.data as ray_data
except (ImportError, ModuleNotFoundError):

ray_data = None

from xgboost_ray import RayDMatrix
from xgboost_ray.matrix import (concat_dataframes, RayShardingMode,
Expand All @@ -29,7 +34,7 @@ def setUp(self):

@classmethod
def setUpClass(cls):
ray.init(local_mode=True)
ray.init()

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -315,6 +320,11 @@ def testDetectDistributed(self):
mat = RayDMatrix([csv_file] * 3, lazy=True)
self.assertTrue(mat.distributed)

if ray_data:
ds = ray_data.read_parquet(parquet_file)
mat = RayDMatrix(ds)
self.assertTrue(mat.distributed)

def testTooManyActorsDistributed(self):
"""Test error when too many actors are passed"""
with self.assertRaises(RuntimeError):
Expand Down

0 comments on commit 2ff8fcc

Please sign in to comment.