[SPARK-1630] Turn Null of Java/Scala into None of Python #1551

Closed
wants to merge 8 commits into from
55 changes: 35 additions & 20 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,6 +22,7 @@ import java.net._
import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.annotation.tailrec
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import scala.util.Try
@@ -270,9 +271,10 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val NULL = -4
}

private[spark] object PythonRDD extends Logging {
private[spark] object PythonRDD {
val UTF8 = Charset.forName("UTF-8")

/**
@@ -312,42 +314,51 @@ private[spark] object PythonRDD extends Logging {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}

@tailrec
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
// The right way to implement this would be to use TypeTags to get the full
// type of T. Since I don't want to introduce breaking changes throughout the
// entire Spark API, I have to use this hacky approach:
def writeBytes(bytes: Array[Byte]) {
Contributor:

Is there a legitimate case where an Iterator[Array[Byte]] will contain a null? I was hoping we'd only have to worry about nulls in an Iterator[String].

Contributor Author:

Array[Byte] is similar to String: a null can be generated by a user's functions or RDDs, for example

RDD[String].map(x => if (x != null) x.getBytes("UTF-8") else null)
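
A minimal standalone sketch of the scenario described above (the RDD contents and the helper name toBytesPreservingNulls are illustrative, not part of this patch): a user-defined map over an RDD[String] that preserves nulls yields an RDD[Array[Byte]] containing nulls, which is what writeIteratorToStream would then receive.

import org.apache.spark.SparkContext

def toBytesPreservingNulls(sc: SparkContext) = {
  // An RDD[String] that already contains a null element.
  val strings = sc.parallelize(Seq("spark", null, "pyspark"))
  // A user-defined map that keeps nulls produces an RDD[Array[Byte]] with nulls in it.
  strings.map(s => if (s != null) s.getBytes("UTF-8") else null)
}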

if (bytes == null) {
dataOut.writeInt(SpecialLengths.NULL)
} else {
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
}
if (iter.hasNext) {
val first = iter.next()
val newIter = Seq(first).iterator ++ iter
first match {
case arr: Array[Byte] =>
newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { writeBytes(_) }
case string: String =>
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
newIter.asInstanceOf[Iterator[String]].foreach { writeUTF(_, dataOut) }
case pair: Tuple2[_, _] =>
pair._1 match {
case bytePair: Array[Byte] =>
newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
dataOut.writeInt(pair._1.length)
dataOut.write(pair._1)
dataOut.writeInt(pair._2.length)
dataOut.write(pair._2)
newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach {
case (k, v) =>
writeBytes(k)
writeBytes(v)
}
case stringPair: String =>
newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
writeUTF(pair._1, dataOut)
writeUTF(pair._2, dataOut)
newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach {
case (k, v) =>
writeUTF(k, dataOut)
writeUTF(v, dataOut)
}
case other =>
throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
if (other == null) {
dataOut.writeInt(SpecialLengths.NULL)
Contributor:

Maybe it doesn't matter much here, but would it make sense to write a byte instead of an int?

Contributor Author:

It's the header of a variable-length field; it's better to keep this header at a fixed length, or you will need to deal with a special variable-length encoding.
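
A minimal sketch of the framing this thread discusses (the object and method names are made up for illustration): every record starts with a fixed 4-byte signed length header, negative values such as -4 are reserved as markers like SpecialLengths.NULL, and the payload follows only when the length is non-negative, so the reader always consumes exactly four header bytes and never needs variable-length decoding.

import java.io.{DataInputStream, DataOutputStream}

object FramingSketch {
  val NULL_SENTINEL = -4  // mirrors SpecialLengths.NULL

  def writeRecord(bytes: Array[Byte], out: DataOutputStream): Unit =
    if (bytes == null) out.writeInt(NULL_SENTINEL)          // 4-byte header, no payload
    else { out.writeInt(bytes.length); out.write(bytes) }   // 4-byte header + payload

  def readRecord(in: DataInputStream): Array[Byte] = {
    val length = in.readInt()                               // the header is always 4 bytes
    if (length == NULL_SENTINEL) null
    else {
      val buf = new Array[Byte](length)
      in.readFully(buf)
      buf
    }
  }
}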

writeIteratorToStream(iter, dataOut)
Contributor:

This method isn't tail-recursive, so this will cause a StackOverflowError if you try to write an iterator with thousands of consecutive nulls.

Contributor:

It looks like we only have to worry about nulls when writing iterators from user-defined RDDs of strings. So, if we see an iterator that begins with null, we can assume that the remainder of the iterator contains only nulls or strings. Therefore, I think you can write out the first null followed by

iter.asInstanceOf[Iterator[String]].foreach { str =>
  writeUTF(str, dataOut)
}

to process the remainder of the stream.

Contributor:

I was wrong; this is tail-recursive. If we only expect nulls to occur in iterators of strings, then I think we should be able to remove the null checking here.
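
For reference on the tail-recursion point: the recursive call is the last action taken on the null branch, so @tailrec makes the compiler turn it into a loop and a long run of consecutive nulls is consumed without growing the stack. A small standalone sketch of the same shape (the method name is made up; this is not the patch code):

import scala.annotation.tailrec

// Consumes leading nulls one element per step; because the self-call is in tail
// position, @tailrec guarantees it compiles to a loop rather than deeper stack frames.
@tailrec
def skipLeadingNulls(iter: Iterator[String]): Iterator[String] = {
  if (!iter.hasNext) iter
  else {
    val first = iter.next()
    if (first == null) skipLeadingNulls(iter)   // tail call: constant stack depth
    else Iterator(first) ++ iter                // put the element back and stop
  }
}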

Contributor Author:

I think it's better to handle NPEs as much as possible, until you can prove that an NPE will not happen.

Contributor:

But this is what I didn't understand about the whole PR: user code is not meant to call PythonRDD directly. Note that the whole PythonRDD object is private[spark]. So where in the codebase today can we get nulls there?

Contributor:

Right, but that's a private API, so it doesn't matter. Does our own code do it?

Basically I'm worried that this significantly complicates our code for something that shouldn't happen. I'd rather have an NPE if our own code later passes nulls here (because it really shouldn't be doing that, since we control everything we pass in).

Contributor Author:

If users want to call Java/Scala UDFs from PySpark, they have to use this private API to do it, so it's possible to have nulls in an RDD[String] or RDD[Array[Byte]].

BTW, it would be helpful if we could skip some bad rows during map/reduce, as mentioned in the MapReduce paper. This is not a must-have feature, but it really improves the robustness of the whole framework, which is very useful for large-scale jobs.

This PR tries to improve the stability of PySpark, letting users feel safer and happier hacking in PySpark.

Contributor:

Again, sorry, I don't think this improves stability:

  1. Users are not supposed to call private APIs. In fact, even Scala code can't call PythonRDD, because it is private[spark] -- it's just an artifact of the way Scala implements package-private that the class becomes public in Java. If you'd like support for UDFs, we need to add that as a separate, top-level feature.
  2. This change would mask bugs in the current way we write Python converters. Our current converters only pass in Strings and arrays of bytes, which shouldn't be null. (For datasets that contain null, they convert it to a pickled form of None already.) This means that if someone introduces a bug in one of our existing code paths, that bug will be harder to fix because instead of being an NPE, it will be some weird value coming out in Python.

Contributor:

BTW apart from the stability issue above with catching our own bugs, the reason I'm commenting is that this change also adds some moderately tricky code in a fairly important code path, increasing the chance of adding new bugs. That doesn't seem worth it to me.

Contributor Author:

OK, let's hold it.

} else {
throw new SparkException("Unexpected element type " + first.getClass)
}
}
}
}
@@ -527,9 +538,13 @@ private[spark] object PythonRDD extends Logging {
}

def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
if (str == null) {
dataOut.writeInt(SpecialLengths.NULL)
} else {
val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
}

def writeToFile[T](items: java.util.Iterator[T], filename: String) {
@@ -23,11 +23,21 @@ import org.scalatest.FunSuite

class PythonRDDSuite extends FunSuite {

test("Writing large strings to the worker") {
val input: List[String] = List("a"*100000)
val buffer = new DataOutputStream(new ByteArrayOutputStream)
PythonRDD.writeIteratorToStream(input.iterator, buffer)
}
test("Writing large strings to the worker") {
val input: List[String] = List("a" * 100000)
val buffer = new DataOutputStream(new ByteArrayOutputStream)
PythonRDD.writeIteratorToStream(input.iterator, buffer)
}

}
test("Handle nulls gracefully") {
val input: List[String] = List("a", null)
val buffer = new DataOutputStream(new ByteArrayOutputStream)
PythonRDD.writeIteratorToStream(input.iterator, buffer)
}

test("Handle list starts with nulls gracefully") {
val input: List[String] = List(null, null, "a", null)
val buffer = new DataOutputStream(new ByteArrayOutputStream)
PythonRDD.writeIteratorToStream(input.iterator, buffer)
}
}
3 changes: 3 additions & 0 deletions python/pyspark/serializers.py
@@ -75,6 +75,7 @@ class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
NULL = -4


class Serializer(object):
@@ -336,6 +337,8 @@ class UTF8Deserializer(Serializer):

def loads(self, stream):
length = read_int(stream)
if length == SpecialLengths.NULL:
return None
return stream.read(length).decode('utf8')

def load_stream(self, stream):