[Train] Force GBDTTrainer to use distributed loading for Ray Datasets (ray-project#31079)

Signed-off-by: amogkam <[email protected]>

Closes ray-project#31068

xgboost_ray has two modes for data loading (a minimal sketch follows the list):

1. A centralized mode, where the driver first loads all of the data and then partitions it for the remote training actors to load.
2. A distributed mode, where the remote training actors load their data partitions directly.
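
To make the two modes concrete, here is a minimal sketch using xgboost_ray's RayDMatrix, whose distributed flag selects between them. The toy dataset and the "target" column are made up for illustration; this is a sketch of how the flag is used, not code from this PR:

    import ray.data
    from xgboost_ray import RayDMatrix

    # Toy dataset with a made-up "target" label column.
    dataset = ray.data.from_items(
        [{"x": float(i), "target": i % 2} for i in range(1000)]
    )

    # Option 2: each remote training actor loads its own shard of the data.
    distributed_matrix = RayDMatrix(dataset, label="target", distributed=True)

    # Option 1: the driver loads everything up front and partitions it for
    # the actors (the mode this PR avoids for Ray Datasets).
    centralized_matrix = RayDMatrix(dataset, label="target", distributed=False)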
When using Ray Datasets with xgboost_ray, we should always do distributed data loading (option 2). However, this is no longer the case after ray-project#30575 was merged.

ray-project#30575 adds an __iter__ method to Ray Datasets, causing isinstance(dataset, Iterable) to return True.
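
This is standard-library duck typing rather than anything Ray-specific; a minimal self-contained illustration:

    from collections.abc import Iterable

    class WithoutIter:
        pass

    class WithIter:
        def __iter__(self):
            return iter([])

    # Iterable's __subclasshook__ only checks for the presence of __iter__,
    # so any class that defines it is treated as an Iterable.
    assert not isinstance(WithoutIter(), Iterable)
    assert isinstance(WithIter(), Iterable)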

This causes Ray Dataset inputs to enter this if statement: https://github.com/ray-project/xgboost_ray/blob/v0.1.12/xgboost_ray/matrix.py#L943-L949, which makes xgboost_ray conclude that Ray Datasets are not distributed and fall back to option 1 for loading.

This centralized loading leads to excessive object spilling and ultimately crashes large-scale xgboost training.

In this PR, we force distributed data loading when using the AIR GBDTTrainers.
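
For context, the trainer-side change is roughly equivalent to the manual workaround of passing a distributed flag for every Ray Dataset via dmatrix_params, as in this hedged sketch (the dataset, column names, and params below are illustrative, not from this PR):

    import ray.data
    from ray.air.config import ScalingConfig
    from ray.train.xgboost import XGBoostTrainer

    train_dataset = ray.data.from_items(
        [{"x": float(i), "target": i % 2} for i in range(1000)]
    )

    trainer = XGBoostTrainer(
        scaling_config=ScalingConfig(num_workers=2),
        label_column="target",
        params={"objective": "binary:logistic"},
        datasets={"train": train_dataset},
        # Before this PR, users could force distributed loading themselves;
        # the trainer now sets this flag for every Ray Dataset it receives.
        dmatrix_params={"train": {"distributed": True}},
    )
    result = trainer.fit()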

In a follow-up, we should clean up the distributed-detection logic directly in xgboost_ray, remove input formats that are no longer supported, and then cut a new release.

Signed-off-by: tmynn <[email protected]>
amogkam authored and tamohannes committed Jan 25, 2023
1 parent 0719d61 commit dc3deed
Showing 2 changed files with 30 additions and 1 deletion.
python/ray/train/gbdt_trainer.py (10 changes: 9 additions & 1 deletion)
@@ -158,8 +158,10 @@ def __init__(
     ):
         self.label_column = label_column
         self.params = params
-        self.dmatrix_params = dmatrix_params or {}
+
         self.train_kwargs = train_kwargs
+        self.dmatrix_params = dmatrix_params or {}
+
         super().__init__(
             scaling_config=scaling_config,
             run_config=run_config,
@@ -168,6 +170,12 @@ def __init__(
             resume_from_checkpoint=resume_from_checkpoint,
         )
 
+        # Ray Datasets should always use distributed loading.
+        for dataset_name in self.datasets.keys():
+            dataset_params = self.dmatrix_params.get(dataset_name, {})
+            dataset_params["distributed"] = True
+            self.dmatrix_params[dataset_name] = dataset_params
+
     def _validate_attributes(self):
         super()._validate_attributes()
         self._validate_config_and_datasets()
python/ray/train/tests/test_xgboost_trainer.py (21 changes: 21 additions & 0 deletions)
@@ -237,6 +237,27 @@ def test_validation(ray_start_4_cpus):
     )
 
 
+def test_distributed_data_loading(ray_start_4_cpus):
+    """Checks that XGBoostTrainer does distributed data loading for Ray Datasets."""
+
+    class DummyXGBoostTrainer(XGBoostTrainer):
+        def _train(self, params, dtrain, **kwargs):
+            assert dtrain.distributed
+            return super()._train(params=params, dtrain=dtrain, **kwargs)
+
+    train_dataset = ray.data.from_pandas(train_df)
+
+    trainer = DummyXGBoostTrainer(
+        scaling_config=ScalingConfig(num_workers=2),
+        label_column="target",
+        params=params,
+        datasets={TRAIN_DATASET_KEY: train_dataset},
+    )
+
+    assert trainer.dmatrix_params[TRAIN_DATASET_KEY]["distributed"]
+    trainer.fit()
+
+
 if __name__ == "__main__":
     import pytest
     import sys
