From 66ce18c04466729e77ead601db4e18f928d440d4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 11 May 2015 13:00:37 -0700 Subject: [PATCH] some cleanups before sending to Xiangrui --- python/pyspark/ml/param/__init__.py | 9 --------- python/pyspark/ml/tuning.py | 17 ++++++++++++----- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 76baa36ff5cf5..71b6d2e10c99a 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -243,12 +243,3 @@ def _copyValues(self, to, extra={}): if paramMap.has_key(p) and to.hasParam(p.name): to._set((p.name, paramMap[p])) return to - - @staticmethod - def _copyParamMap(paramMap, to): - """ - Create a copy of the given ParamMap, but with parameter - :param paramMap: - :param to: - :return: - """ \ No newline at end of file diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 8b1b2c6664c40..19dd84899c595 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -232,14 +232,21 @@ def fit(self, dataset, params={}): def copy(self, extra={}): """ Creates a copy of this instance with a randomly generated uid - and some extra params. This copies the underlying estimator, creates a deep copy of the embedded paramMap, and - copies the embedded and extra parameters over. + and some extra params. This copies the underlying estimator, + evaluator, and estimatorParamMap, creates a deep copy of the + embedded paramMap, and copies the embedded and extra parameters + over. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ - paramMap = self.extractParamMap(extra) - stages = map(lambda stage: stage.copy(extra), paramMap[self.stages]) - return CrossValidator().setStages(stages) + newCV = Params.copy(self, extra) + if self.isSet(self.estimator): + newCV.setEstimator(self.getEstimator().copy(extra)) + if self.isSet(self.estimatorParamMaps): + newCV.setEstimatorParamMaps(self.getEstimatorParamMaps().MAGIC_COPY_TO_BE_IMPLEMENTED(extra)) # TODO + if self.isSet(self.evaluator): + newCV.setEvaluator(self.getEvaluator().copy(extra)) + return newCV class CrossValidatorModel(Model):