From 07923c42fcb4d210333b3882490e23f33dc4822f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 13:48:42 -0800 Subject: [PATCH 01/18] support kafka in Python --- python/pyspark/streaming/kafka.py | 71 +++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 python/pyspark/streaming/kafka.py diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py new file mode 100644 index 0000000000000..f383362acfac3 --- /dev/null +++ b/python/pyspark/streaming/kafka.py @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from py4j.java_collections import MapConverter +from py4j.java_gateway import java_import, Py4JError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['KafkaUtils'] + + +class KafkaUtils(object): + + @staticmethod + def createStream(ssc, zkQuorum, groupId, topics, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + keyDecoder=None, valueDecoder=None): + """ + Create an input stream that pulls messages from a Kafka Broker. + + :param ssc: StreamingContext object + :param zkQuorum: Zookeeper quorum (hostname:port,hostname:port,..). + :param groupId: The group id for this consumer. + :param topics: Dict of (topic_name -> numPartitions) to consume. + Each partition is consumed in its own thread. + :param storageLevel: RDD storage level. 
+ :param keyDecoder: A function used to decode key + :param valueDecoder: A function used to decode value + :return: A DStream object + """ + java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") + + if not isinstance(topics, dict): + raise TypeError("topics should be dict") + jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + try: + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, jlevel) + except Py4JError, e: + if 'call a package' in e.message: + print "No kafka package, please build it and add it into classpath:" + print " $ sbt/sbt streaming-kafka/package" + print " $ bin/submit --driver-class-path external/kafka/target/scala-2.10/" \ + "spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" + raise Exception("No kafka package") + raise e + ser = PairDeserializer(UTF8Deserializer(), UTF8Deserializer()) + stream = DStream(jstream, ssc, ser) + + if keyDecoder is not None: + stream = stream.map(lambda (k, v): (keyDecoder(k), v)) + if valueDecoder is not None: + stream = stream.mapValues(valueDecoder) + return stream \ No newline at end of file From 75d485e65b75a7a5da91a37ff42a9bb7cd82dcf6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 14:18:59 -0800 Subject: [PATCH 02/18] add mqtt --- python/pyspark/streaming/kafka.py | 2 +- python/pyspark/streaming/mqtt.py | 53 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 python/pyspark/streaming/mqtt.py diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index f383362acfac3..ed52f853658fd 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -68,4 +68,4 @@ def createStream(ssc, zkQuorum, groupId, topics, stream = stream.map(lambda (k, v): (keyDecoder(k), v)) if valueDecoder is not None: stream = stream.mapValues(valueDecoder) - return stream \ No newline at end of file + return stream diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py new file mode 100644 index 0000000000000..78e4cb7f14649 --- /dev/null +++ b/python/pyspark/streaming/mqtt.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from py4j.java_gateway import java_import, Py4JError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import UTF8Deserializer +from pyspark.streaming import DStream + +__all__ = ['MQTTUtils'] + + +class MQTTUtils(object): + + @staticmethod + def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream that receives messages pushed by a MQTT publisher. 
+ + :param ssc: StreamingContext object + :param brokerUrl: Url of remote MQTT publisher + :param topic: Topic name to subscribe to + :param storageLevel: RDD storage level. + :return: A DStream object + """ + java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils") + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + try: + jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel) + except Py4JError, e: + if 'call a package' in e.message: + print "No MQTT package, please build it and add it into classpath:" + print " $ sbt/sbt streaming-mqtt/package" + print " $ bin/submit --driver-class-path external/mqtt/target/scala-2.10/" \ + "spark-streaming-mqtt_2.10-1.3.0-SNAPSHOT.jar" + raise Exception("No mqtt package") + raise e + return DStream(jstream, ssc, UTF8Deserializer()) From 048dbe6c9ec4bff4eeee52c70e4e18d48d3075e0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 16 Dec 2014 14:27:43 -0800 Subject: [PATCH 03/18] fix python style --- python/pyspark/streaming/kafka.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index ed52f853658fd..24d62db1979b9 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -52,7 +52,8 @@ def createStream(ssc, zkQuorum, groupId, topics, jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, jlevel) + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, + jlevel) except Py4JError, e: if 'call a package' in e.message: print "No kafka package, please build it and add it into classpath:" From 5697a012def1b8508d21d96ccccf2afb7d6705cf Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Dec 2014 15:44:33 -0800 Subject: [PATCH 04/18] bypass decoder in scala --- python/pyspark/streaming/kafka.py | 39 +++++++++++++++++++------------ 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 24d62db1979b9..6d0cb5d4c1f63 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,23 +15,26 @@ # limitations under the License. # - from py4j.java_collections import MapConverter from py4j.java_gateway import java_import, Py4JError from pyspark.storagelevel import StorageLevel -from pyspark.serializers import PairDeserializer, UTF8Deserializer +from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream __all__ = ['KafkaUtils'] +def utf8_decoder(s): + return s.decode('utf-8') + + class KafkaUtils(object): @staticmethod def createStream(ssc, zkQuorum, groupId, topics, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, - keyDecoder=None, valueDecoder=None): + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ Create an input stream that pulls messages from a Kafka Broker. 
@@ -47,26 +50,32 @@ def createStream(ssc, zkQuorum, groupId, topics, """ java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") + param = { + "zookeeper.connect": zkQuorum, + "group.id": groupId, + "zookeeper.connection.timeout.ms": "10000", + } if not isinstance(topics, dict): raise TypeError("topics should be dict") jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) + jparam = MapConverter().convert(param, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + def getClassByName(name): + return ssc._jvm.org.apache.spark.util.Utils.classForName(name) + try: - jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, zkQuorum, groupId, jtopics, - jlevel) + array = getClassByName("[B") + decoder = getClassByName("kafka.serializer.DefaultDecoder") + jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, + jparam, jtopics, jlevel) except Py4JError, e: if 'call a package' in e.message: print "No kafka package, please build it and add it into classpath:" print " $ sbt/sbt streaming-kafka/package" - print " $ bin/submit --driver-class-path external/kafka/target/scala-2.10/" \ - "spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" - raise Exception("No kafka package") + print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ + "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" raise e - ser = PairDeserializer(UTF8Deserializer(), UTF8Deserializer()) + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) - - if keyDecoder is not None: - stream = stream.map(lambda (k, v): (keyDecoder(k), v)) - if valueDecoder is not None: - stream = stream.mapValues(valueDecoder) - return stream + return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) From 98c8d179d3ff264d03eabc3ddd72936d95e6e305 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Dec 2014 15:58:29 -0800 Subject: [PATCH 05/18] fix python style --- python/pyspark/streaming/kafka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 6d0cb5d4c1f63..c27699fee6b83 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -74,7 +74,7 @@ def getClassByName(name): print "No kafka package, please build it and add it into classpath:" print " $ sbt/sbt streaming-kafka/package" print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ - "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" + "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) From f6ce899abd435f36f7c5907523c643cc8b0e61ed Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Jan 2015 13:28:39 -0800 Subject: [PATCH 06/18] add example and fix bugs --- .../apache/spark/api/python/PythonRDD.scala | 55 ++++++++++++------- .../main/python/streaming/kafka_wordcount.py | 55 +++++++++++++++++++ python/pyspark/serializers.py | 7 ++- python/pyspark/streaming/kafka.py | 8 ++- python/pyspark/streaming/mqtt.py | 53 ------------------ 5 files changed, 100 insertions(+), 78 deletions(-) create mode 100644 examples/src/main/python/streaming/kafka_wordcount.py delete mode 100644 python/pyspark/streaming/mqtt.py diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index bad40e6529f74..b47b381374dc7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -313,6 +313,7 @@ private object SpecialLengths { val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 val END_OF_STREAM = -4 + val NULL = -5 } private[spark] object PythonRDD extends Logging { @@ -374,49 +375,61 @@ private[spark] object PythonRDD extends Logging { // The right way to implement this would be to use TypeTags to get the full // type of T. Since I don't want to introduce breaking changes throughout the // entire Spark API, I have to use this hacky approach: + def write(bytes: Array[Byte]) { + if (bytes == null) { + dataOut.writeInt(SpecialLengths.NULL) + } else { + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + } + def writeS(str: String) { + if (str == null) { + dataOut.writeInt(SpecialLengths.NULL) + } else { + writeUTF(str, dataOut) + } + } if (iter.hasNext) { val first = iter.next() val newIter = Seq(first).iterator ++ iter first match { case arr: Array[Byte] => - newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes => - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } + newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(write) case string: String => - newIter.asInstanceOf[Iterator[String]].foreach { str => - writeUTF(str, dataOut) - } + newIter.asInstanceOf[Iterator[String]].foreach(writeS) case stream: PortableDataStream => newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) + write(stream.toArray()) } case (key: String, stream: PortableDataStream) => newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { case (key, stream) => - writeUTF(key, dataOut) - val bytes = stream.toArray() - dataOut.writeInt(bytes.length) - dataOut.write(bytes) + writeS(key) + write(stream.toArray()) } case (key: String, value: String) => newIter.asInstanceOf[Iterator[(String, String)]].foreach { case (key, value) => - writeUTF(key, dataOut) - writeUTF(value, dataOut) + writeS(key) + writeS(value) } case (key: Array[Byte], value: Array[Byte]) => newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { case (key, value) => - dataOut.writeInt(key.length) - dataOut.write(key) - dataOut.writeInt(value.length) - dataOut.write(value) + write(key) + write(value) + } + // key is null + case (null, v:Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + write(key) + write(value) } + case other => - throw new SparkException("Unexpected element type " + first.getClass) + throw new SparkException("Unexpected element type " + other.getClass) } } } diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py new file mode 100644 index 0000000000000..400c05fb7a05b --- /dev/null +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: network_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first + $ bin/zookeeper-server-start.sh config/zookeeper.properties + $ bin/kafka-server-start.sh config/server.properties + $ bin/kafka-console-producer.sh --broker-list localhost:9092 --topic test + + and then run the example + `$ bin/spark-submit --driver-class-path lib_managed/jars/kafka_*.jar:\ + external/kafka/target/scala-*/spark-streaming-kafka_*.jar examples/src/main/python/\ + streaming/kafka_wordcount.py localhost:2181 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingKafkaWordCount") + ssc = StreamingContext(sc, 1) + + zkQuorum, topic = sys.argv[1:] + lines = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1}) + counts = lines.map(lambda x: x[1]).flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index bd08c9a6d20d6..3cec646f3336d 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -70,6 +70,7 @@ class SpecialLengths(object): PYTHON_EXCEPTION_THROWN = -2 TIMING_DATA = -3 END_OF_STREAM = -4 + NULL = -5 class Serializer(object): @@ -145,8 +146,10 @@ def _read_with_length(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + if length == SpecialLengths.NULL: + return None obj = stream.read(length) - if obj == "": + if len(obj) < length: raise EOFError return self.loads(obj) @@ -480,6 +483,8 @@ def loads(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError + if length == SpecialLengths.NULL: + return None s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index c27699fee6b83..f52d0b535094f 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -22,11 +22,12 @@ from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils'] +__all__ = ['KafkaUtils', 'utf8_decoder'] def utf8_decoder(s): - return s.decode('utf-8') + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') class KafkaUtils(object): @@ -70,7 +71,8 @@ def getClassByName(name): jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder, jparam, jtopics, jlevel) except Py4JError, e: - if 'call a package' in e.message: + # TODO: use --jar once it also work on driver + if not e.message or 'call a package' in e.message: print "No kafka package, please 
build it and add it into classpath:" print " $ sbt/sbt streaming-kafka/package" print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py deleted file mode 100644 index 78e4cb7f14649..0000000000000 --- a/python/pyspark/streaming/mqtt.py +++ /dev/null @@ -1,53 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - - -from py4j.java_gateway import java_import, Py4JError - -from pyspark.storagelevel import StorageLevel -from pyspark.serializers import UTF8Deserializer -from pyspark.streaming import DStream - -__all__ = ['MQTTUtils'] - - -class MQTTUtils(object): - - @staticmethod - def createStream(ssc, brokerUrl, topic, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): - """ - Create an input stream that receives messages pushed by a MQTT publisher. - - :param ssc: StreamingContext object - :param brokerUrl: Url of remote MQTT publisher - :param topic: Topic name to subscribe to - :param storageLevel: RDD storage level. 
- :return: A DStream object - """ - java_import(ssc._jvm, "org.apache.spark.streaming.mqtt.MQTTUtils") - jlevel = ssc._sc._getJavaStorageLevel(storageLevel) - try: - jstream = ssc._jvm.MQTTUtils.createStream(ssc._jssc, brokerUrl, topic, jlevel) - except Py4JError, e: - if 'call a package' in e.message: - print "No MQTT package, please build it and add it into classpath:" - print " $ sbt/sbt streaming-mqtt/package" - print " $ bin/submit --driver-class-path external/mqtt/target/scala-2.10/" \ - "spark-streaming-mqtt_2.10-1.3.0-SNAPSHOT.jar" - raise Exception("No mqtt package") - raise e - return DStream(jstream, ssc, UTF8Deserializer()) From eea16a79e741255548ef2e006db3948771a47e0d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 8 Jan 2015 13:29:35 -0800 Subject: [PATCH 07/18] refactor --- examples/src/main/python/streaming/kafka_wordcount.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 400c05fb7a05b..8b9ae9903aacb 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -45,8 +45,9 @@ ssc = StreamingContext(sc, 1) zkQuorum, topic = sys.argv[1:] - lines = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1}) - counts = lines.map(lambda x: x[1]).flatMap(lambda line: line.split(" ")) \ + kvs = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ .map(lambda word: (word, 1)) \ .reduceByKey(lambda a, b: a+b) counts.pprint() From aea89538dcb9b80111f98df881d345d4e87e91aa Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 21 Jan 2015 17:31:30 -0800 Subject: [PATCH 08/18] Kafka-assembly for Python API --- external/kafka-assembly/pom.xml | 106 ++++++++++++++++++++++++++++++++ make-distribution.sh | 1 + pom.xml | 1 + project/SparkBuild.scala | 8 ++- 4 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 external/kafka-assembly/pom.xml diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml new file mode 100644 index 0000000000000..503fc129dc4f2 --- /dev/null +++ b/external/kafka-assembly/pom.xml @@ -0,0 +1,106 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.3.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kafka-assembly_2.10 + jar + Spark Project External Kafka Assembly + http://spark.apache.org/ + + + streaming-kafka-assembly + scala-${scala.binary.version} + spark-streaming-kafka-assembly-${project.version}.jar + ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} + + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${spark.jar} + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/make-distribution.sh b/make-distribution.sh index 45c99e42e5a5b..dd00cca4428b8 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -181,6 +181,7 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > 
"$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +cp "$FWDIR"/external/kafka/scala*/*kafka*assembly*.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : diff --git a/pom.xml b/pom.xml index 9e6fe09d95bbe..95950bf0b2a97 100644 --- a/pom.xml +++ b/pom.xml @@ -1422,6 +1422,7 @@ external/kafka + external/kafka-assembly diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index ff8cf81b286af..d679b126fd047 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -42,8 +42,9 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn) = - Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _)) + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. @@ -298,6 +299,9 @@ object Assembly { if (mName.contains("network-yarn")) { // This must match the same name used in maven (see network/yarn/pom.xml) "spark-" + v + "-yarn-shuffle.jar" + } else if (mName.contains("streaming-kafka-assembly")) { + // This must match the same name used in maven (see external/kafka-assembly/pom.xml) + mName + "-" + v + ".jar" } else { mName + "-" + v + "-hadoop" + Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" From 2c567a5d55c465d706026c2395e9025fad9dbd68 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 22 Jan 2015 00:01:02 -0800 Subject: [PATCH 09/18] update logging and comment --- examples/src/main/python/streaming/kafka_wordcount.py | 7 ++++--- python/pyspark/streaming/kafka.py | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 8b9ae9903aacb..dad760aa4db54 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -22,12 +22,13 @@ To run this on your local machine, you need to setup Kafka and create a producer first $ bin/zookeeper-server-start.sh config/zookeeper.properties $ bin/kafka-server-start.sh config/server.properties + $ bin/kafka-topics.sh --create --zookeeper localhost:2181 --partitions 1 --topic test $ bin/kafka-console-producer.sh --broker-list localhost:9092 --topic test and then run the example - `$ bin/spark-submit --driver-class-path lib_managed/jars/kafka_*.jar:\ - external/kafka/target/scala-*/spark-streaming-kafka_*.jar examples/src/main/python/\ - streaming/kafka_wordcount.py localhost:2181 test` + `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ + localhost:2181 test` """ import sys diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index f52d0b535094f..2e898c06fcf8f 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -73,10 +73,9 @@ def getClassByName(name): except 
Py4JError, e: # TODO: use --jar once it also work on driver if not e.message or 'call a package' in e.message: - print "No kafka package, please build it and add it into classpath:" - print " $ sbt/sbt streaming-kafka/package" - print " $ bin/submit --driver-class-path lib_managed/jars/kafka_2.10-0.8.0.jar:" \ - "external/kafka/target/scala-2.10/spark-streaming-kafka_2.10-1.3.0-SNAPSHOT.jar" + print "No kafka package, please put the assembly jar into classpath:" + print " $ bin/submit --driver-class-path external/kafka-assembly/target/" + \ + "scala-*/spark-streaming-kafka-assembly-*.jar" raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) From 97386b3debd5f352b61dfed194ab9495fecbe834 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 22 Jan 2015 00:08:06 -0800 Subject: [PATCH 10/18] address comment --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 98412f1faa9ea..6f5d3dda377de 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -383,6 +383,7 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } } + def writeS(str: String) { if (str == null) { dataOut.writeInt(SpecialLengths.NULL) @@ -390,6 +391,7 @@ private[spark] object PythonRDD extends Logging { writeUTF(str, dataOut) } } + if (iter.hasNext) { val first = iter.next() val newIter = Seq(first).iterator ++ iter From 370ba61571b98e9bdfb6636852d4404687143853 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 26 Jan 2015 12:25:12 -0800 Subject: [PATCH 11/18] Update kafka.py fix spark-submit --- python/pyspark/streaming/kafka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 2e898c06fcf8f..ff1a951b6f11b 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -74,7 +74,7 @@ def getClassByName(name): # TODO: use --jar once it also work on driver if not e.message or 'call a package' in e.message: print "No kafka package, please put the assembly jar into classpath:" - print " $ bin/submit --driver-class-path external/kafka-assembly/target/" + \ + print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \ "scala-*/spark-streaming-kafka-assembly-*.jar" raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) From 31e2317a31c90b23a7b085c7fd5a1de8998194a6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 27 Jan 2015 11:21:11 -0800 Subject: [PATCH 12/18] Update kafka_wordcount.py --- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index dad760aa4db54..7a023ee410735 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -39,7 +39,7 @@ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: network_wordcount.py " + print >> sys.stderr, "Usage: kafka_wordcount.py " exit(-1) sc = SparkContext(appName="PythonStreamingKafkaWordCount") From dc1eed0a6af190d5cf07dedcb0607a0a76e45d64 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 27 Jan 2015 
11:47:20 -0800 Subject: [PATCH 13/18] Update kafka_wordcount.py --- examples/src/main/python/streaming/kafka_wordcount.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 7a023ee410735..58201f967664a 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -19,12 +19,9 @@ Counts words in UTF8 encoded, '\n' delimited text received from the network every second. Usage: network_wordcount.py - To run this on your local machine, you need to setup Kafka and create a producer first - $ bin/zookeeper-server-start.sh config/zookeeper.properties - $ bin/kafka-server-start.sh config/server.properties - $ bin/kafka-topics.sh --create --zookeeper localhost:2181 --partitions 1 --topic test - $ bin/kafka-console-producer.sh --broker-list localhost:9092 --topic test - + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + and then run the example `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ From a74da876085ae51c93464d2d73787e226457bda0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 27 Jan 2015 22:52:23 -0800 Subject: [PATCH 14/18] address comments --- .../apache/spark/api/python/PythonRDD.scala | 35 ++++++++++--------- python/pyspark/serializers.py | 2 +- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 6f5d3dda377de..ff5f7a0e0d3fc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -372,10 +372,8 @@ private[spark] object PythonRDD extends Logging { } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { - // The right way to implement this would be to use TypeTags to get the full - // type of T. Since I don't want to introduce breaking changes throughout the - // entire Spark API, I have to use this hacky approach: - def write(bytes: Array[Byte]) { + + def writeBytes(bytes: Array[Byte]) { if (bytes == null) { dataOut.writeInt(SpecialLengths.NULL) } else { @@ -384,7 +382,7 @@ private[spark] object PythonRDD extends Logging { } } - def writeS(str: String) { + def writeString(str: String) { if (str == null) { dataOut.writeInt(SpecialLengths.NULL) } else { @@ -392,42 +390,45 @@ private[spark] object PythonRDD extends Logging { } } + // The right way to implement this would be to use TypeTags to get the full + // type of T. 
Since I don't want to introduce breaking changes throughout the + // entire Spark API, I have to use this hacky approach: if (iter.hasNext) { val first = iter.next() val newIter = Seq(first).iterator ++ iter first match { case arr: Array[Byte] => - newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(write) + newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes) case string: String => - newIter.asInstanceOf[Iterator[String]].foreach(writeS) + newIter.asInstanceOf[Iterator[String]].foreach(writeString) case stream: PortableDataStream => newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => - write(stream.toArray()) + writeBytes(stream.toArray()) } case (key: String, stream: PortableDataStream) => newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { case (key, stream) => - writeS(key) - write(stream.toArray()) + writeString(key) + writeBytes(stream.toArray()) } case (key: String, value: String) => newIter.asInstanceOf[Iterator[(String, String)]].foreach { case (key, value) => - writeS(key) - writeS(value) + writeString(key) + writeString(value) } case (key: Array[Byte], value: Array[Byte]) => newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { case (key, value) => - write(key) - write(value) + writeBytes(key) + writeBytes(value) } // key is null - case (null, v:Array[Byte]) => + case (null, value: Array[Byte]) => newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { case (key, value) => - write(key) - write(value) + writeBytes(key) + writeBytes(value) } case other => diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index df95ce9622573..4c930d45ee251 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -487,7 +487,7 @@ def loads(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError - if length == SpecialLengths.NULL: + elif length == SpecialLengths.NULL: return None s = stream.read(length) return s.decode("utf-8") if self.use_unicode else s From 23b039a896497c8f4cae1bf963274ff295841c37 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 27 Jan 2015 22:56:42 -0800 Subject: [PATCH 15/18] address comments --- examples/src/main/python/streaming/kafka_wordcount.py | 2 +- python/pyspark/serializers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py index 58201f967664a..ed398a82b8bb0 100644 --- a/examples/src/main/python/streaming/kafka_wordcount.py +++ b/examples/src/main/python/streaming/kafka_wordcount.py @@ -21,7 +21,7 @@ To run this on your local machine, you need to setup Kafka and create a producer first, see http://kafka.apache.org/documentation.html#quickstart - + and then run the example `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\ spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4c930d45ee251..2a84635e0c0a1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -146,7 +146,7 @@ def _read_with_length(self, stream): length = read_int(stream) if length == SpecialLengths.END_OF_DATA_SECTION: raise EOFError - if length == SpecialLengths.NULL: + elif length == SpecialLengths.NULL: return None obj = stream.read(length) if len(obj) < length: From f257071feeec5c10ce953cf2163a1aa116526bbd Mon Sep 17 
00:00:00 2001 From: Davies Liu Date: Thu, 29 Jan 2015 12:08:26 -0800 Subject: [PATCH 16/18] add tests for null in RDD --- .../apache/spark/api/python/PythonRDD.scala | 71 ++++--------------- .../apache/spark/api/python/PythonUtils.scala | 5 ++ .../spark/api/python/PythonRDDSuite.scala | 22 ++++-- python/pyspark/serializers.py | 2 + python/pyspark/streaming/kafka.py | 13 ++-- python/pyspark/tests.py | 8 ++- 6 files changed, 52 insertions(+), 69 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ff5f7a0e0d3fc..a13a1b923c5f6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -373,68 +373,27 @@ private[spark] object PythonRDD extends Logging { def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { - def writeBytes(bytes: Array[Byte]) { - if (bytes == null) { + def write(obj: Any): Unit = obj match { + case null => dataOut.writeInt(SpecialLengths.NULL) - } else { - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - } - def writeString(str: String) { - if (str == null) { - dataOut.writeInt(SpecialLengths.NULL) - } else { + case arr: Array[Byte] => + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => writeUTF(str, dataOut) - } - } - // The right way to implement this would be to use TypeTags to get the full - // type of T. Since I don't want to introduce breaking changes throughout the - // entire Spark API, I have to use this hacky approach: - if (iter.hasNext) { - val first = iter.next() - val newIter = Seq(first).iterator ++ iter - first match { - case arr: Array[Byte] => - newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes) - case string: String => - newIter.asInstanceOf[Iterator[String]].foreach(writeString) - case stream: PortableDataStream => - newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => - writeBytes(stream.toArray()) - } - case (key: String, stream: PortableDataStream) => - newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { - case (key, stream) => - writeString(key) - writeBytes(stream.toArray()) - } - case (key: String, value: String) => - newIter.asInstanceOf[Iterator[(String, String)]].foreach { - case (key, value) => - writeString(key) - writeString(value) - } - case (key: Array[Byte], value: Array[Byte]) => - newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { - case (key, value) => - writeBytes(key) - writeBytes(value) - } - // key is null - case (null, value: Array[Byte]) => - newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { - case (key, value) => - writeBytes(key) - writeBytes(value) - } + case stream: PortableDataStream => + write(stream.toArray()) + case (key, value) => + write(key) + write(value) - case other => - throw new SparkException("Unexpected element type " + other.getClass) - } + case other => + throw new SparkException("Unexpected element type " + other.getClass) } + + iter.foreach(write) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index be5ebfa9219d3..b7cfc8bd9c542 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -22,6 +22,7 @@ import java.io.{File, InputStream, IOException, OutputStream} import 
scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext +import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} private[spark] object PythonUtils { /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ @@ -39,4 +40,8 @@ private[spark] object PythonUtils { def mergePythonPaths(paths: String*): String = { paths.filter(_ != "").mkString(File.pathSeparator) } + + def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { + sc.parallelize(List("a", null, "b")) + } } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index 7b866f08a0e9f..f4cf02977033e 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -23,11 +23,21 @@ import org.scalatest.FunSuite class PythonRDDSuite extends FunSuite { - test("Writing large strings to the worker") { - val input: List[String] = List("a"*100000) - val buffer = new DataOutputStream(new ByteArrayOutputStream) - PythonRDD.writeIteratorToStream(input.iterator, buffer) - } + test("Writing large strings to the worker") { + val input: List[String] = List("a"*100000) + val buffer = new DataOutputStream(new ByteArrayOutputStream) + PythonRDD.writeIteratorToStream(input.iterator, buffer) + } -} + test("Handle nulls gracefully") { + val buffer = new DataOutputStream(new ByteArrayOutputStream) + PythonRDD.writeIteratorToStream(List("a", null).iterator, buffer) + PythonRDD.writeIteratorToStream(List(null, "a").iterator, buffer) + PythonRDD.writeIteratorToStream(List("a".getBytes, null).iterator, buffer) + PythonRDD.writeIteratorToStream(List(null, "a".getBytes).iterator, buffer) + PythonRDD.writeIteratorToStream(List((null, null), ("a", null), (null, "b")).iterator, buffer) + PythonRDD.writeIteratorToStream( + List((null, null), ("a".getBytes, null), (null, "b".getBytes)).iterator, buffer) + } +} diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2a84635e0c0a1..0ffb41d02f6f6 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -134,6 +134,8 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if serialized is None: + raise ValueError("serialized value should not be None") if len(serialized) > (1 << 31): raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index ff1a951b6f11b..19ad71f99d4d5 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -33,7 +33,7 @@ def utf8_decoder(s): class KafkaUtils(object): @staticmethod - def createStream(ssc, zkQuorum, groupId, topics, + def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ @@ -44,22 +44,23 @@ def createStream(ssc, zkQuorum, groupId, topics, :param groupId: The group id for this consumer. :param topics: Dict of (topic_name -> numPartitions) to consume. Each partition is consumed in its own thread. + :param kafkaParams: Additional params for Kafka :param storageLevel: RDD storage level. 
- :param keyDecoder: A function used to decode key - :param valueDecoder: A function used to decode value + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object """ java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils") - param = { + kafkaParams.update({ "zookeeper.connect": zkQuorum, "group.id": groupId, "zookeeper.connection.timeout.ms": "10000", - } + }) if not isinstance(topics, dict): raise TypeError("topics should be dict") jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(param, ssc.sparkContext._gateway._gateway_client) + jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) def getClassByName(name): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..56c36bef753d4 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -46,9 +46,10 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext +from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType @@ -714,6 +715,11 @@ def test_sample(self): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def test_null_in_rdd(self): + jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) + rdd = RDD(jrdd, self.sc, UTF8Deserializer()) + self.assertEqual([u"a", None, u"b"], rdd.collect()) + class ProfilerTests(PySparkTestCase): From 4280d04a32c69024bd200e407275a123b8373035 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 29 Jan 2015 17:25:31 -0800 Subject: [PATCH 17/18] address comments --- .../org/apache/spark/api/python/PythonRDD.scala | 3 --- .../apache/spark/api/python/PythonRDDSuite.scala | 15 ++++++++------- python/pyspark/tests.py | 4 +++- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a13a1b923c5f6..3308f155ccf2e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -376,19 +376,16 @@ private[spark] object PythonRDD extends Logging { def write(obj: Any): Unit = obj match { case null => dataOut.writeInt(SpecialLengths.NULL) - case arr: Array[Byte] => dataOut.writeInt(arr.length) dataOut.write(arr) case str: String => writeUTF(str, dataOut) - case stream: PortableDataStream => write(stream.toArray()) case (key, value) => write(key) write(value) - case other => throw new SparkException("Unexpected element type " + other.getClass) } diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala index f4cf02977033e..c63d834f9048b 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala 
@@ -31,13 +31,14 @@ class PythonRDDSuite extends FunSuite { test("Handle nulls gracefully") { val buffer = new DataOutputStream(new ByteArrayOutputStream) - PythonRDD.writeIteratorToStream(List("a", null).iterator, buffer) - PythonRDD.writeIteratorToStream(List(null, "a").iterator, buffer) - PythonRDD.writeIteratorToStream(List("a".getBytes, null).iterator, buffer) - PythonRDD.writeIteratorToStream(List(null, "a".getBytes).iterator, buffer) - - PythonRDD.writeIteratorToStream(List((null, null), ("a", null), (null, "b")).iterator, buffer) + // Should not have NPE when write an Iterator with null in it + // The correctness will be tested in Python + PythonRDD.writeIteratorToStream(Iterator("a", null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer) + PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer) + PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer) + PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer) PythonRDD.writeIteratorToStream( - List((null, null), ("a".getBytes, null), (null, "b".getBytes)).iterator, buffer) + Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer) } } diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index d001b39d615de..9f07bd49d5fd8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,7 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType @@ -720,6 +720,8 @@ def test_null_in_rdd(self): jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc) rdd = RDD(jrdd, self.sc, UTF8Deserializer()) self.assertEqual([u"a", None, u"b"], rdd.collect()) + rdd = RDD(jrdd, self.sc, NoOpSerializer()) + self.assertEqual(["a", None, "b"], rdd.collect()) def test_multiple_python_java_RDD_conversions(self): # Regression test for SPARK-5361 From d93bfe01486256e8f3618928cf0670a1b2ce59b7 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 2 Feb 2015 18:11:51 -0800 Subject: [PATCH 18/18] Update make-distribution.sh --- make-distribution.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index 8313ca047c030..051c87c0894ae 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -196,7 +196,6 @@ echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$SPARK_HOME"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" -cp "$SPARK_HOME"/external/kafka/scala*/*kafka*assembly*.jar "$DISTDIR/lib/" # This will fail if the -Pyarn profile is not provided # In this case, silence the error and ignore the return code of this command cp "$SPARK_HOME"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || :
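
Usage sketch — a minimal, hypothetical example of driving the Python Kafka API introduced in this series with a custom value decoder. The ZooKeeper address, consumer group id, and "events" topic below are illustrative assumptions, not values taken from the patches; the job would be submitted with the spark-streaming-kafka assembly jar on the driver classpath, as described in kafka_wordcount.py.

    import json

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils, utf8_decoder

    sc = SparkContext(appName="PythonStreamingKafkaJSON")
    ssc = StreamingContext(sc, 2)  # 2-second batches

    # Keys are decoded as UTF-8 (the default); values are parsed as JSON in Python.
    # The decoder guards against None, since this series lets null keys/values
    # survive the Java-to-Python serialization path (SpecialLengths.NULL).
    kvs = KafkaUtils.createStream(
        ssc, "localhost:2181", "example-consumer", {"events": 1},
        keyDecoder=utf8_decoder,
        valueDecoder=lambda v: json.loads(v) if v is not None else None)

    # kvs is a DStream of (key, parsed_value) pairs.
    kvs.map(lambda kv: kv[1]).pprint()

    ssc.start()
    ssc.awaitTermination()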