Skip to content

Commit

Permalink
[backport] Fix feature types param (#8772) (#8801)
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <[email protected]>
Co-authored-by: WeichenXu <[email protected]>
  • Loading branch information
trivialfis and WeichenXu123 committed Feb 14, 2023
1 parent 60303db commit 08a547f
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@
}


# TODO: supply hint message for all other unsupported params.
_unsupported_params_hint_message = {
"enable_categorical": "`xgboost.spark` estimators do not have 'enable_categorical' param, "
"but you can set `feature_types` param and mark categorical features with 'c' string."
}


class _SparkXGBParams(
HasFeaturesCol,
HasLabelCol,
Expand Down Expand Up @@ -523,7 +530,10 @@ def setParams(self, **kwargs): # pylint: disable=invalid-name
or k in _unsupported_predict_params
or k in _unsupported_train_params
):
raise ValueError(f"Unsupported param '{k}'.")
err_msg = _unsupported_params_hint_message.get(
k, f"Unsupported param '{k}'."
)
raise ValueError(err_msg)
_extra_params[k] = v
_existing_extra_params = self.getOrDefault(self.arbitrary_params_dict)
self._set(arbitrary_params_dict={**_existing_extra_params, **_extra_params})
Expand Down Expand Up @@ -749,6 +759,8 @@ def _fit(self, dataset):
"feature_weights": self.getOrDefault(self.feature_weights),
"missing": float(self.getOrDefault(self.missing)),
}
if dmatrix_kwargs["feature_types"] is not None:
dmatrix_kwargs["enable_categorical"] = True
booster_params["nthread"] = cpu_per_task
use_gpu = self.getOrDefault(self.use_gpu)

Expand Down

0 comments on commit 08a547f

Please sign in to comment.