diff --git a/python/ray/train/gbdt_trainer.py b/python/ray/train/gbdt_trainer.py
index debdb526ae9a..48a19866e011 100644
--- a/python/ray/train/gbdt_trainer.py
+++ b/python/ray/train/gbdt_trainer.py
@@ -158,8 +158,10 @@ def __init__(
     ):
         self.label_column = label_column
         self.params = params
-        self.dmatrix_params = dmatrix_params or {}
+
         self.train_kwargs = train_kwargs
+        self.dmatrix_params = dmatrix_params or {}
+
         super().__init__(
             scaling_config=scaling_config,
             run_config=run_config,
@@ -168,6 +170,12 @@ def __init__(
             resume_from_checkpoint=resume_from_checkpoint,
         )
 
+        # Ray Datasets should always use distributed loading.
+        for dataset_name in self.datasets.keys():
+            dataset_params = self.dmatrix_params.get(dataset_name, {})
+            dataset_params["distributed"] = True
+            self.dmatrix_params[dataset_name] = dataset_params
+
     def _validate_attributes(self):
         super()._validate_attributes()
         self._validate_config_and_datasets()
diff --git a/python/ray/train/tests/test_xgboost_trainer.py b/python/ray/train/tests/test_xgboost_trainer.py
index 38f8b1fc130e..931515e207ad 100644
--- a/python/ray/train/tests/test_xgboost_trainer.py
+++ b/python/ray/train/tests/test_xgboost_trainer.py
@@ -237,6 +237,27 @@ def test_validation(ray_start_4_cpus):
     )
 
 
+def test_distributed_data_loading(ray_start_4_cpus):
+    """Checks that XGBoostTrainer does distributed data loading for Ray Datasets."""
+
+    class DummyXGBoostTrainer(XGBoostTrainer):
+        def _train(self, params, dtrain, **kwargs):
+            assert dtrain.distributed
+            return super()._train(params=params, dtrain=dtrain, **kwargs)
+
+    train_dataset = ray.data.from_pandas(train_df)
+
+    trainer = DummyXGBoostTrainer(
+        scaling_config=ScalingConfig(num_workers=2),
+        label_column="target",
+        params=params,
+        datasets={TRAIN_DATASET_KEY: train_dataset},
+    )
+
+    assert trainer.dmatrix_params[TRAIN_DATASET_KEY]["distributed"]
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import pytest
     import sys
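
Note on the `dmatrix_params` change above: the new loop in `__init__` merges a
forced `distributed=True` flag into whatever per-dataset kwargs the caller
supplied, overriding even an explicit `distributed=False`. A minimal sketch of
that merge behavior, with hypothetical dataset keys ("train", "valid") and a
hypothetical user kwarg ("missing"); the loop body itself is copied verbatim
from the diff:

    # Hypothetical caller inputs; only the loop below comes from the diff.
    datasets = {"train": ..., "valid": ...}
    dmatrix_params = {"train": {"missing": 0.0, "distributed": False}}

    # Ray Datasets should always use distributed loading.
    for dataset_name in datasets.keys():
        dataset_params = dmatrix_params.get(dataset_name, {})
        dataset_params["distributed"] = True
        dmatrix_params[dataset_name] = dataset_params

    # User-supplied kwargs survive; the distributed flag is always overridden,
    # and datasets with no entry get one created for them.
    assert dmatrix_params == {
        "train": {"missing": 0.0, "distributed": True},
        "valid": {"distributed": True},
    }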