diff --git a/lightgbm_ray/sklearn.py b/lightgbm_ray/sklearn.py index b58997c..8c21686 100644 --- a/lightgbm_ray/sklearn.py +++ b/lightgbm_ray/sklearn.py @@ -36,7 +36,7 @@ from lightgbm import LGBMModel, LGBMClassifier, LGBMRegressor # LGBMRanker from lightgbm.basic import _choose_param_value, _ConfigAliases from xgboost_ray.sklearn import (_wrap_evaluation_matrices, - _check_if_params_are_ray_dmatrix) + _check_if_params_are_ray_dmatrix, RayXGBMixin) from lightgbm_ray.main import train, predict, RayDMatrix, RayParams import warnings @@ -83,7 +83,14 @@ def _treat_method_doc(doc: str, insert_before: str) -> str: return doc -class _RayLGBMModel: +class _RayLGBMModel(RayXGBMixin): + def _ray_get_wrap_evaluation_matrices_compat_kwargs( + self, label_transform=None) -> dict: + self.enable_categorical = False + self.feature_types = None + return super()._ray_get_wrap_evaluation_matrices_compat_kwargs( + label_transform=label_transform) + def _ray_set_ray_params_n_jobs( self, ray_params: Optional[Union[RayParams, dict]], n_jobs: Optional[int]) -> RayParams: @@ -133,7 +140,7 @@ def _ray_fit(self, eval_init_score) if train_dmatrix is None: - wrap_evaluation_matrices_kwargs = dict( + train_dmatrix, evals = _wrap_evaluation_matrices( missing=None, X=X, y=y, @@ -150,18 +157,9 @@ def _ray_fit(self, # changed in xgboost-ray: create_dmatrix=lambda **kwargs: RayDMatrix(**{ **kwargs, - **ray_dmatrix_params - })) - try: - train_dmatrix, evals = _wrap_evaluation_matrices( - **wrap_evaluation_matrices_kwargs) - except TypeError as e: - if "enable_categorical" in str(e): - train_dmatrix, evals = _wrap_evaluation_matrices( - **wrap_evaluation_matrices_kwargs, - enable_categorical=False) - else: - raise e + **ray_dmatrix_params, + }), + **self._ray_get_wrap_evaluation_matrices_compat_kwargs()) eval_names = eval_names or [] diff --git a/requirements-test.txt b/requirements-test.txt index 9f924a4..e33dc7c 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -5,7 +5,7 @@ yapf==0.23.0 petastorm pytest pyarrow -ray[tune] +ray[tune, data] scikit-learn modin git+https://github.com/ray-project/xgboost_ray.git @@ -13,4 +13,5 @@ parameterized packaging # workaround for now +protobuf<4.0.0 tensorboardX==2.2