diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 5224a0b49a991..2e7e2cf820f9d 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -115,22 +115,31 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint + +data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") + +# Preprocessing +splitData = data.map(lambda line: line.split(',')) +parsedData = splitData.map( + lambda parts: LabeledPoint( + float(parts[0]), + Vectors.dense(map(lambda x: float(x), parts[1].split(' '))) + ) + ) -# an RDD of LabeledPoint -data = sc.parallelize([ - LabeledPoint(0.0, [0.0, 0.0]) - ... # more labeled points -]) +# Split data into training (60%) and test (40%) +training, test = parsedData.randomSplit([0.6, 0.4], seed = 0) # Train a naive Bayes model. -model = NaiveBayes.train(data, 1.0) +model = NaiveBayes.train(training, 1.0) -# Make prediction. -prediction = model.predict([0.0, 0.0]) +# Make prediction and test accuracy. +predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) +accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() {% endhighlight %}