diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py index f8b525db8edd6..ca6e01e9577c4 100644 --- a/python/pyspark/ml/connect/classification.py +++ b/python/pyspark/ml/connect/classification.py @@ -41,7 +41,7 @@ ) from pyspark.ml.connect.base import Predictor, PredictionModel from pyspark.ml.connect.io_utils import ParamsReadWrite, CoreModelReadWrite -from pyspark.sql.functions import lit, count, countDistinct +from pyspark.sql import functions as sf import torch import torch.nn as torch_nn @@ -232,18 +232,20 @@ def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegressionMo num_train_workers ) - # TODO: check label values are in range of [0, num_classes) - num_rows, num_classes = dataset.agg( - count(lit(1)), countDistinct(self.getLabelCol()) + num_rows, num_features, classes = dataset.select( + sf.count(sf.lit(1)), + sf.first(sf.array_size(self.getFeaturesCol())), + sf.collect_set(self.getLabelCol()), ).head() # type: ignore[misc] - num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size) - num_samples_per_worker = num_batches_per_worker * batch_size - - num_features = len(dataset.select(self.getFeaturesCol()).head()[0]) # type: ignore[index] - + num_classes = len(classes) if num_classes < 2: raise ValueError("Training dataset distinct labels must >= 2.") + if any(c not in range(0, num_classes) for c in classes): + raise ValueError("Training labels must be integers in [0, numClasses).") + + num_batches_per_worker = math.ceil(num_rows / num_train_workers / batch_size) + num_samples_per_worker = num_batches_per_worker * batch_size # TODO: support GPU. distributor = TorchDistributor(