From 5c532baed28bb88442e115e27e0b2c7fd20659df Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 11 Apr 2023 21:27:33 +0800 Subject: [PATCH] init Signed-off-by: Weichen Xu --- python-package/xgboost/spark/core.py | 6 +++++- tests/test_distributed/test_with_spark/test_spark_local.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 1a614f51f0a3..ec47a8c23477 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -931,7 +931,11 @@ def _run_job(): result_xgb_model = self._convert_to_sklearn_model( bytearray(booster, "utf-8"), config ) - return self._copyValues(self._create_pyspark_model(result_xgb_model)) + spark_model = self._create_pyspark_model(result_xgb_model) + # According to pyspark ML convention, the model uid should be the same + # as the estimator uid. + spark_model._resetUid(self.uid) + return self._copyValues(spark_model) def write(self): """ diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 0ffdb2a2bde2..a5e0f028a060 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -464,6 +464,7 @@ class TestPySparkLocal: def test_regressor_basic(self, reg_data: RegData) -> None: regressor = SparkXGBRegressor(pred_contrib_col="pred_contribs") model = regressor.fit(reg_data.reg_df_train) + assert regressor.uid == model.uid pred_result = model.transform(reg_data.reg_df_test).collect() for row in pred_result: np.testing.assert_equal(row.prediction, row.expected_prediction)