Skip to content

Commit

Permalink
[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None
Browse files Browse the repository at this point in the history
If the initial model passed to GMM is not None, training fails with net.razorvine.pickle.PickleException. This is fixed by converting initialModel.weights to a list.

Author: zero323 <[email protected]>

Closes #10644 from zero323/SPARK-12006.

(cherry picked from commit 592f649)
Signed-off-by: Joseph K. Bradley <[email protected]>
  • Loading branch information
zero323 authored and jkbradley committed Jan 7, 2016
1 parent 33ab236 commit 7b4fdf3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
if initialModel.k != k:
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
% (initialModel.k, k))
initialModelWeights = initialModel.weights
initialModelWeights = list(initialModel.weights)
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
Expand Down
12 changes: 12 additions & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,18 @@ def test_gmm_deterministic(self):
for c1, c2 in zip(clusters1.weights, clusters2.weights):
self.assertEquals(round(c1, 7), round(c2, 7))

def test_gmm_with_initial_model(self):
    """Regression test for SPARK-12006: training with ``initialModel`` set.

    Trains a two-component mixture, then trains again on the same data
    seeded with the first model. The second call must not raise, and the
    resulting mixing weights must match the first model's component by
    component.
    """
    from pyspark.mllib.clustering import GaussianMixture
    data = self.sc.parallelize([
        (-10, -5), (-9, -4), (10, 5), (9, 4)
    ])

    gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
                                 maxIterations=10, seed=63)
    gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
                                 maxIterations=10, seed=63, initialModel=gmm1)
    # Compare element-wise: each model's weights sum to 1, so the sum of
    # the differences is ~0 no matter how far apart the weights are and
    # would never catch a regression.
    for w1, w2 in zip(gmm1.weights, gmm2.weights):
        self.assertAlmostEqual(w1, w2)

def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\
Expand Down

0 comments on commit 7b4fdf3

Please sign in to comment.