diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
index 6485ebb8081a4..def8bdc8e318d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala
@@ -17,64 +17,75 @@ package org.apache.spark.examples.pythonconverters
 
-import java.nio.ByteBuffer
+import java.util.{Collection => JCollection, Map => JMap}
+
 import scala.collection.JavaConversions._
 
-import org.apache.avro.generic.{GenericArray, GenericFixed, GenericRecord}
+import org.apache.avro.generic.{GenericFixed, IndexedRecord}
 import org.apache.avro.mapred.AvroWrapper
 import org.apache.avro.Schema
 import org.apache.avro.Schema.Type._
-import org.apache.avro.util.Utf8
 
 import org.apache.spark.api.python.Converter
 import org.apache.spark.SparkException
 
 
 /**
- * Implementation of [[org.apache.spark.api.python.Converter]] that converts an
- * Avro Wrapper to Java object. It only works with Avro's Generic Data Model.
+ * Implementation of [[org.apache.spark.api.python.Converter]] that converts
+ * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries
+ * to work with all 3 Avro data models.
  */
 class AvroWrapperToJavaConverter extends Converter[Any, Any] {
   override def convert(obj: Any): Any = obj.asInstanceOf[AvroWrapper[_]].datum() match {
-    case record: GenericRecord => unpackRecord(record, None)
+    case record: IndexedRecord => unpackRecord(record)
     case other => throw new SparkException(
       s"Unsupported top-level Avro data type ${other.getClass.getName}")
   }
 
-  def unpackRecord(obj: Any, readerSchema: Option[Schema]): java.util.Map[String, Any] = {
-    val record = obj.asInstanceOf[GenericRecord]
+  def unpackRecord(obj: Any): JMap[String, Any] = {
     val map = new java.util.HashMap[String, Any]
-    readerSchema.getOrElse(record.getSchema).getFields.foreach { case f =>
-      map.put(f.name, fromAvro(record.get(f.name), f.schema))
+    obj match {
+      case record: IndexedRecord =>
+        record.getSchema.getFields.zipWithIndex.foreach { case (f, i) =>
+          map.put(f.name, fromAvro(record.get(i), f.schema))
+        }
+      case other => throw new SparkException(
+        s"Unsupported RECORD type ${other.getClass.getName}")
     }
     map
   }
 
-  def unpackMap(obj: Any, schema: Schema): java.util.Map[String, Any] = {
-    val map = new java.util.HashMap[String, Any]
-    obj.asInstanceOf[Map[Utf8, _]].foreach { case (key, value) =>
-      map.put(key.toString, fromAvro(value, schema.getValueType))
+  def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = {
+    obj.asInstanceOf[JMap[_, _]].map { case (key, value) =>
+      (key.toString, fromAvro(value, schema.getValueType))
     }
-    map
   }
 
   def unpackFixed(obj: Any, schema: Schema): Array[Byte] = {
     unpackBytes(obj.asInstanceOf[GenericFixed].bytes())
   }
 
-  def unpackBytes(bytes: Array[Byte]): Array[Byte] = {
+  def unpackBytes(obj: Any): Array[Byte] = {
+    val bytes: Array[Byte] = obj match {
+      case buf: java.nio.ByteBuffer => buf.array()
+      case arr: Array[Byte] => arr
+      case other => throw new SparkException(
+        s"Unknown BYTES type ${other.getClass.getName}")
+    }
     val bytearray = new Array[Byte](bytes.length)
     System.arraycopy(bytes, 0, bytearray, 0, bytes.length)
     bytearray
   }
 
-  def unpackArray(obj: Any, schema: Schema): java.util.List[Any] = {
-    val list = new java.util.ArrayList[Any]
-    obj.asInstanceOf[GenericArray[_]].foreach { element =>
-      list.add(fromAvro(element, schema.getElementType))
-    }
-    list
+  def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match {
+    case c: JCollection[_] =>
+      c.map(fromAvro(_, schema.getElementType))
+    case arr: Array[_] if arr.getClass.getComponentType.isPrimitive =>
+      arr.toSeq
+    case arr: Array[_] =>
+      arr.map(fromAvro(_, schema.getElementType)).toSeq
+    case other => throw new SparkException(
+      s"Unknown ARRAY type ${other.getClass.getName}")
   }
 
   def unpackUnion(obj: Any, schema: Schema): Any = {
@@ -95,9 +106,9 @@ class AvroWrapperToJavaConverter extends Converter[Any, Any] {
       case UNION => unpackUnion(obj, schema)
       case ARRAY => unpackArray(obj, schema)
       case FIXED => unpackFixed(obj, schema)
-      case BYTES => unpackBytes(obj.asInstanceOf[ByteBuffer].array())
-      case RECORD => unpackRecord(obj, Option(schema))
       case MAP => unpackMap(obj, schema)
+      case BYTES => unpackBytes(obj)
+      case RECORD => unpackRecord(obj)
       case STRING => obj.toString
       case ENUM => obj.toString
       case NULL => obj
@@ -107,7 +118,7 @@ class AvroWrapperToJavaConverter extends Converter[Any, Any] {
       case INT => obj
       case LONG => obj
      case other => throw new SparkException(
-        s"Unsupported Avro schema type ${other.getName}")
+        s"Unknown Avro schema type ${other.getName}")
     }
   }
 }
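
A minimal sketch (not part of the patch) of how the revised converter behaves end to end: it builds an Avro GenericData.Record, which implements IndexedRecord and is therefore accepted by the new convert(), wraps it in an AvroWrapper, and converts it to a java.util.Map. The schema literal, field names, and the AvroConverterSketch object are illustrative assumptions, and the example assumes the patched AvroWrapperToJavaConverter is on the classpath.

import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.mapred.AvroWrapper
import org.apache.spark.examples.pythonconverters.AvroWrapperToJavaConverter

object AvroConverterSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative record schema: an int, an array of strings, and a bytes field.
    val schema = new Schema.Parser().parse(
      """{"type": "record", "name": "Rec", "fields": [
        |  {"name": "id",   "type": "int"},
        |  {"name": "tags", "type": {"type": "array", "items": "string"}},
        |  {"name": "blob", "type": "bytes"}
        |]}""".stripMargin)

    // GenericData.Record implements IndexedRecord, so it matches the new
    // `case record: IndexedRecord` branch in convert().
    val rec = new GenericData.Record(schema)
    rec.put("id", 42)
    // A java.util.List hits the JCollection case of the new unpackArray.
    rec.put("tags", java.util.Arrays.asList("a", "b"))
    // A ByteBuffer hits the ByteBuffer case of the new unpackBytes.
    rec.put("blob", java.nio.ByteBuffer.wrap(Array[Byte](1, 2)))

    // convert() unwraps the AvroWrapper and recursively turns the datum
    // into Java types via fromAvro, returning a java.util.Map.
    val converted = new AvroWrapperToJavaConverter().convert(new AvroWrapper(rec))
    println(converted) // e.g. {id=42, tags=[a, b], blob=[B@1a2b3c}
  }
}

The same record fed to the pre-patch converter would have failed on the "tags" field, since a plain java.util.List is not a GenericArray; dispatching on JCollection and Array instead of concrete Generic* classes is what lets the converter accept data from the Generic, Specific, and Reflect Avro models alike.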