Skip to content

Commit

Permalink
[SPARK-44908][ML][CONNECT] Fix cross validator foldCol param function…
Browse files Browse the repository at this point in the history
…ality

### What changes were proposed in this pull request?

Fix cross validator foldCol param functionality.
In main branch the code calls `df.rdd` APIs but it is not supported in spark connect

### Why are the changes needed?

Bug fix.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#42605 from WeichenXu123/fix-tuning-connect-foldCol.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Aug 23, 2023
1 parent 4d90c59 commit 0d1b597
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 17 deletions.
24 changes: 7 additions & 17 deletions python/pyspark/ml/connect/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@
)
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasParallelism, HasSeed
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType
from pyspark.sql.functions import col, lit, rand
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import SparkSession

Expand Down Expand Up @@ -477,23 +476,14 @@ def _kFold(self, dataset: DataFrame) -> List[Tuple[DataFrame, DataFrame]]:
train = df.filter(~condition)
datasets.append((train, validation))
else:
# Use user-specified fold numbers.
def checker(foldNum: int) -> bool:
if foldNum < 0 or foldNum >= nFolds:
raise ValueError(
"Fold number must be in range [0, %s), but got %s." % (nFolds, foldNum)
)
return True

checker_udf = UserDefinedFunction(checker, BooleanType())
# TODO:
# Add verification that foldCol column values are in range [0, nFolds)
for i in range(nFolds):
training = dataset.filter(checker_udf(dataset[foldCol]) & (col(foldCol) != lit(i)))
validation = dataset.filter(
checker_udf(dataset[foldCol]) & (col(foldCol) == lit(i))
)
if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
training = dataset.filter(col(foldCol) != lit(i))
validation = dataset.filter(col(foldCol) == lit(i))
if training.isEmpty():
raise ValueError("The training data at fold %s is empty." % i)
if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
if validation.isEmpty():
raise ValueError("The validation data at fold %s is empty." % i)
datasets.append((training, validation))

Expand Down
25 changes: 25 additions & 0 deletions python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,31 @@ def _verify_cv_saved_params(instance, loaded_instance):
np.testing.assert_allclose(cv_model.avgMetrics, loaded_cv_model.avgMetrics)
np.testing.assert_allclose(cv_model.stdMetrics, loaded_cv_model.stdMetrics)

def test_crossvalidator_with_fold_col(self):
sk_dataset = load_breast_cancer()

train_dataset = self.spark.createDataFrame(
zip(
sk_dataset.data.tolist(),
[int(t) for t in sk_dataset.target],
[int(i % 3) for i in range(len(sk_dataset.target))],
),
schema="features: array<double>, label: long, fold: long",
)

lorv2 = LORV2(numTrainWorkers=2)

grid2 = ParamGridBuilder().addGrid(lorv2.maxIter, [2, 200]).build()
cv = CrossValidator(
estimator=lorv2,
estimatorParamMaps=grid2,
parallelism=2,
evaluator=BinaryClassificationEvaluator(),
foldCol="fold",
numFolds=3,
)
cv.fit(train_dataset)


class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 0d1b597

Please sign in to comment.