-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[train] Updates to support xgboost==2.1.0
#46667
Conversation
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
…10compat Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
…10compat Signed-off-by: Justin Yu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the implementation! Two questions from me:
- The
RabitTracker
class functionsworker_args
andworker_env
resturn the same type of thingsDict[str, Union[int, str]]
. The only difference is the key ofworker_env
are uppercase letters butworker_args
are lowercase letters. Our adaption to this change is to move from env settup to training context settup, is that correct? - The
RabitTracker
class doesn't maintain athread
itself, instead we need to create a main thread kind of thing using itswait_for
method to wait fortracker.start
by ourselves. My question is: Do we also need to distinguish xgboost version before/post 210 in theon_shutdown
method. My first intuition is to using tracker.thread before 210, and using wait_for after 210. Currently, It seems we always use the wai_for method now.
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Yes. The API changed from accepting environment variables to only allowing you to pass the arguments directly as kwargs with those lower-case names. |
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
def on_training_start( | ||
self, worker_group: WorkerGroup, backend_config: XGBoostConfig | ||
): | ||
assert backend_config.xgboost_communicator == "rabit" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it seems XGBoostingConfig
has a hard coded backend_config
field being "rabit"
, why do we still need an assertion here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, I can probably remove this field for now, since we don't support the "federated" option.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some comments, most of them are nits and should be non-blocking.
It should be good to go if all unit tests look good.
@@ -37,28 +41,93 @@ class XGBoostConfig(BackendConfig): | |||
def train_func_context(self): | |||
@contextmanager | |||
def collective_communication_context(): | |||
with CommunicatorContext(): | |||
with CommunicatorContext(**_get_xgboost_args()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we able to save the xgboost_args into XGBoost config so we can avoid modifying the global variable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, interesting. I actually don't understand why we need both BackendConfig
and Backend
classes. Any context here @matthewdeng ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The BackendConfig
is the public API that the user could interact with. There is probably a better/cleaner way to organize the two.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah currently the dependency between BackendConfig
and Backend
are unidirectional. It's kind of hard to pass information from Backed -> BackendConfig
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should train_func_context
be part of the Backend
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or at the very least the default one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh hm maybe that won't work because we construct the train loop before the backend...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Support xgboost 2.1.0, which was recently released and changed some of the distributed setup APIs. --------- Signed-off-by: Justin Yu <[email protected]> Signed-off-by: Dev <[email protected]>
Why are these changes needed?
xgboost 2.1.0 was recently released, and it changed some of the distributed setup APIs.
In particular:
CollectiveCommunicator
was changed to not read from environment variables. Instead, an argument dict is required to be passed in: dmlc/xgboost@a5a5810#diff-a74bc610352aa00eda4ae89c1f3a51c33b934f50f055f366def49654efb42992RabitTracker
API was updated with a few renamed methods and changed behavior:worker_envs()
->worker_args()
RabitTracker.wait_for
must be run as a separate thread in order forworker_args
to return properly: dmlc/xgboost@a5a5810#diff-94301e6ca68aefc564a0d617db4ab2de3425b2cca9d66fe95ad8c7ce97399c14R182-R186This PR branches the setup logic between pre 2.1.0 and post 2.1.0. We should eventually drop pre-2.1.0 support.
Testing
This PR also updates the tested xgboost version to 2.1.0. Pre-2.1.0 has been tested manually.
Related issue number
Closes #46476
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.