[jvm-packages] Add the new `device` parameter. #9385
Conversation
cc @wbo4958.
@dotbg Could you please help take a look at the PR when you are available? I'm not an expert in Scala/Spark.
@trivialfis At first glance the code looks ok. I wonder whether the …
@dotbg
@@ -77,7 +77,8 @@ public void testBooster() throws XGBoostError {
         put("objective", "binary:logistic");
         put("num_round", round);
         put("num_workers", 1);
-        put("tree_method", "gpu_hist");
+        put("tree_method", "hist");
+        put("device", "cuda");
Will the `put("tree_method", "gpu_hist");` also work?
I see. In this case, it may be good (but not a showstopper) to align the naming in other places as well.
It should work, but with a warning. I will add a test tomorrow.
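For reference, a minimal sketch of the kind of backward-compatibility check such a test might make, assuming the `XGBoostClassifier(Map[String, Any])` constructor of xgboost4j-spark and assuming the classifier accepts the new `device` key after this PR; the actual test added in the PR may look different:

```scala
// A hedged sketch, not the PR's actual test: it only illustrates that both the new
// `device=cuda` style and the legacy `tree_method=gpu_hist` style are expected to
// produce a usable classifier, with the legacy style emitting a deprecation warning.
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

object DeviceParamCompatSketch {
  def main(args: Array[String]): Unit = {
    // New, preferred parameters after this PR.
    val newStyle: Map[String, Any] = Map(
      "objective"   -> "binary:logistic",
      "num_round"   -> 10,
      "num_workers" -> 1,
      "tree_method" -> "hist",
      "device"      -> "cuda"
    )

    // Legacy parameters: should still work, but with a warning from the library.
    val oldStyle: Map[String, Any] = Map(
      "objective"   -> "binary:logistic",
      "num_round"   -> 10,
      "num_workers" -> 1,
      "tree_method" -> "gpu_hist"
    )

    // Both constructions are expected to succeed; training on a real dataset
    // would then run on the GPU in either case.
    val clfNew = new XGBoostClassifier(newStyle)
    val clfOld = new XGBoostClassifier(oldStyle)
    println(clfNew.explainParams())
    println(clfOld.explainParams())
  }
}
```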
Could you please elaborate on this? Apologies for not providing clearer context before. I'm currently walking through each interface (#7308 (comment)). The syntax is documented here, which was part of #9362. I think the Flink interface doesn't support GPU yet, but I will double-check the native Java interface and the Scala interface; I think they don't have hard-coded parameters that require changes (feel free to correct me). The R interface doesn't have hard-coded parameters, and the CRAN package doesn't support GPU. The Python interface is mostly handled in previous PRs; I will have some more specialized handling for PySpark. The naming of the parameters is consistent.
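To illustrate the point about the native interfaces: in the Scala binding, booster parameters are passed through as a plain Map, so `device` needs no dedicated handling there. A minimal sketch, assuming a LibSVM-format training file (the path is a placeholder):

```scala
// Sketch of the native (non-Spark) Scala interface: parameters such as `device`
// are forwarded verbatim to the core library, so no hard-coded handling is needed.
import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

object NativeDeviceParamSketch {
  def main(args: Array[String]): Unit = {
    val dtrain = new DMatrix("train.libsvm")  // placeholder path
    val params: Map[String, Any] = Map(
      "objective"   -> "binary:logistic",
      "tree_method" -> "hist",
      "device"      -> "cuda"  // new parameter from #9362; forwarded as-is
    )
    val booster = XGBoost.train(dtrain, params, round = 10)
    booster.dispose()
    dtrain.delete()
  }
}
```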
@trivialfis well, one of my concerns is that the package names will contain …
I think this part should be fine; we have documentation on how GPU support is achieved for both the general XGBoost packages and the JVM packages, along with a note that CUDA is the only option at the moment.
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala (outdated; resolved)
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala (resolved)
@@ -137,8 +137,12 @@ object GpuPreXGBoost extends PreXGBoostProvider {
     val (Seq(labelName, weightName, marginName), feturesCols, groupName, evalSets) =
       estimator match {
         case est: XGBoostEstimatorCommon =>
-          require(est.isDefined(est.treeMethod) && est.getTreeMethod.equals("gpu_hist"),
-            s"GPU train requires tree_method set to gpu_hist")
+          require(
@wbo4958 The check is here.
Related: #7308.
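For context, a hedged sketch of how the relaxed requirement could be expressed, accepting either the new `device=cuda` setting or the legacy `tree_method=gpu_hist`; the PR's actual check in GpuPreXGBoost.scala may differ in detail:

```scala
// Illustrative only: `treeMethod` and `device` stand in for however the estimator
// exposes these parameters; the real check lives in GpuPreXGBoost.scala.
object GpuRequirementSketch {
  def requireGpuParams(treeMethod: Option[String], device: Option[String]): Unit = {
    val gpuByDevice     = device.contains("cuda")           // new style: device=cuda
    val gpuByTreeMethod = treeMethod.contains("gpu_hist")   // legacy style: tree_method=gpu_hist
    require(gpuByDevice || gpuByTreeMethod,
      "GPU training requires either device=cuda (preferred) or tree_method=gpu_hist (deprecated)")
  }

  def main(args: Array[String]): Unit = {
    requireGpuParams(treeMethod = Some("hist"), device = Some("cuda"))   // ok
    requireGpuParams(treeMethod = Some("gpu_hist"), device = None)       // ok, legacy
    // requireGpuParams(treeMethod = Some("hist"), device = None)        // would throw
  }
}
```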
This PR is to add the `device` parameter for the Spark package. For the core implementation PR, please see #9362.