Skip to content

Commit

Permalink
[train] Updates to support xgboost==2.1.0 (#46667)
Browse files Browse the repository at this point in the history
Support xgboost 2.1.0, which was recently released and changed some of the
distributed setup APIs.
---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Aug 8, 2024
1 parent 2f21bcf commit c634872
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 10 deletions.
109 changes: 102 additions & 7 deletions python/ray/train/xgboost/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import json
import logging
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import xgboost
from packaging.version import Version
from xgboost import RabitTracker
from xgboost.collective import CommunicatorContext

Expand Down Expand Up @@ -37,28 +41,93 @@ class XGBoostConfig(BackendConfig):
def train_func_context(self):
    """Return a context manager to wrap each worker's training function.

    The returned context manager enters xgboost's ``CommunicatorContext``
    with the communicator args that the backend broadcast to this worker
    (see ``_set_xgboost_args`` / ``_get_xgboost_args``), so the rabit
    collective is initialized before user code runs and torn down after.
    """

    @contextmanager
    def collective_communication_context():
        # The args include the tracker connection info and the
        # rank-tagged `dmlc_task_id` set in `on_training_start`.
        with CommunicatorContext(**_get_xgboost_args()):
            yield

    return collective_communication_context

@property
def backend_cls(self):
    """Return the backend class for the configured communicator.

    xgboost 2.1.0 changed the distributed setup APIs (e.g. the
    ``RabitTracker`` start/worker-args methods), so the backend
    implementation is selected based on the installed xgboost version.

    Raises:
        NotImplementedError: If ``xgboost_communicator`` is not "rabit".
    """
    if self.xgboost_communicator == "rabit":
        return (
            _XGBoostRabitBackend
            if Version(xgboost.__version__) >= Version("2.1.0")
            else _XGBoostRabitBackend_pre_xgb210
        )

    raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")


class _XGBoostRabitBackend(Backend):
    """Rabit-based distributed backend for xgboost >= 2.1.0.

    Starts a ``RabitTracker`` on the Train driver, then broadcasts the
    tracker's connection args to every worker; the workers enter them via
    ``CommunicatorContext`` (see ``XGBoostConfig.train_func_context``).
    """

    def __init__(self):
        # Tracker coordinating the rabit collective; created in
        # `_setup_xgboost_distributed_backend`.
        self._tracker: Optional[RabitTracker] = None
        # Daemon thread blocking on `RabitTracker.wait_for`; joined
        # (best-effort) in `on_shutdown`.
        self._wait_thread: Optional[threading.Thread] = None

    def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
        # Set up the rabit tracker on the Train driver.
        num_workers = len(worker_group)
        rabit_args = {"n_workers": num_workers}
        train_driver_ip = ray.util.get_node_ip_address()

        # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
        # align with Ray Train worker ranks.
        # The worker ranks will be sorted by `dmlc_task_id`,
        # which is defined below.
        self._tracker = RabitTracker(
            n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
        )
        self._tracker.start()

        # The RabitTracker is started in a separate thread, and the
        # `wait_for` method must be called for `worker_args` to return.
        self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
        self._wait_thread.start()

        rabit_args.update(self._tracker.worker_args())

        start_log = (
            "RabitTracker coordinator started with parameters:\n"
            f"{json.dumps(rabit_args, indent=2)}"
        )
        logger.debug(start_log)

        def set_xgboost_communicator_args(args):
            import ray.train

            # Rank-tagged task id so the tracker sorts workers into the
            # same order as Ray Train world ranks (see sortby="task" above).
            args["dmlc_task_id"] = (
                f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
                f"{ray.get_runtime_context().get_actor_id()}"
            )

            _set_xgboost_args(args)

        worker_group.execute(set_xgboost_communicator_args, rabit_args)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: XGBoostConfig
    ):
        """Start the rabit tracker and share its connection args with workers."""
        # Only the "rabit" communicator is supported by this backend.
        assert backend_config.xgboost_communicator == "rabit"
        self._setup_xgboost_distributed_backend(worker_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
        """Best-effort join of the tracker wait thread during shutdown."""
        timeout = 5

        if self._wait_thread is not None:
            self._wait_thread.join(timeout=timeout)

            # The thread is a daemon, so a failed join only delays cleanup;
            # it does not keep the process alive.
            if self._wait_thread.is_alive():
                logger.warning(
                    "During shutdown, the RabitTracker thread failed to join "
                    f"within {timeout} seconds. "
                    "The process will still be terminated as part of Ray actor cleanup."
                )


class _XGBoostRabitBackend_pre_xgb210(Backend):
def __init__(self):
    # Tracker for the legacy (pre-xgboost-2.1.0) rabit collective;
    # created lazily in `_setup_xgboost_distributed_backend`.
    self._tracker: Optional[RabitTracker] = None

def _setup_xgboost_distributed_backend(self, worker_group: WorkerGroup):
# Set up the rabit tracker on the Train driver.
num_workers = len(worker_group)
rabit_args = {"DMLC_NUM_WORKER": num_workers}
Expand All @@ -67,12 +136,14 @@ def on_training_start(
# NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
# align with Ray Train worker ranks.
# The worker ranks will be sorted by `DMLC_TASK_ID`,
# which is defined in `on_training_start`.
# which is defined below.
self._tracker = RabitTracker(
host_ip=train_driver_ip, n_workers=num_workers, sortby="task"
n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
)
rabit_args.update(self._tracker.worker_envs())
self._tracker.start(num_workers)
self._tracker.start(n_workers=num_workers)

worker_args = self._tracker.worker_envs()
rabit_args.update(worker_args)

start_log = (
"RabitTracker coordinator started with parameters:\n"
Expand All @@ -95,7 +166,16 @@ def set_xgboost_env_vars():

worker_group.execute(set_xgboost_env_vars)

def on_training_start(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
    """Launch the rabit tracker and configure the workers' env vars."""
    # This backend only implements the "rabit" communicator.
    assert backend_config.xgboost_communicator == "rabit"
    self._setup_xgboost_distributed_backend(worker_group)

def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
if not self._tracker:
return

timeout = 5
self._tracker.thread.join(timeout=timeout)

Expand All @@ -105,3 +185,18 @@ def on_shutdown(self, worker_group: WorkerGroup, backend_config: XGBoostConfig):
f"within {timeout} seconds. "
"The process will still be terminated as part of Ray actor cleanup."
)


_xgboost_args: dict = {}
_xgboost_args_lock = threading.Lock()


def _set_xgboost_args(args):
with _xgboost_args_lock:
global _xgboost_args
_xgboost_args = args


def _get_xgboost_args() -> dict:
with _xgboost_args_lock:
return _xgboost_args
2 changes: 1 addition & 1 deletion python/requirements/ml/core-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mlflow==2.9.2
wandb==0.17.0

# ML training frameworks
xgboost==1.7.6
xgboost==2.1.0
lightgbm==3.3.5

# Huggingface
Expand Down
8 changes: 6 additions & 2 deletions python/requirements_compiled.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,10 @@ numpy==1.26.4
# webdataset
# xgboost
# zoopt
nvidia-nccl-cu12==2.20.5
# via
# -r /ray/ci/../python/requirements/ml/core-requirements.txt
# xgboost
oauth2client==4.1.3
# via
# gcs-oauth2-boto-plugin
Expand Down Expand Up @@ -2504,7 +2508,7 @@ wrapt==1.14.1
# tensorflow-datasets
wurlitzer==3.1.1
# via comet-ml
xgboost==1.7.6
xgboost==2.1.0
# via -r /ray/ci/../python/requirements/ml/core-requirements.txt
xlrd==2.0.1
# via -r /ray/ci/../python/requirements/test-requirements.txt
Expand Down Expand Up @@ -2544,4 +2548,4 @@ zoopt==0.4.1

# The following packages are considered to be unsafe in a requirements file:
# pip
# setuptools
# setuptools

0 comments on commit c634872

Please sign in to comment.