diff --git a/src/scala/microsoft-spark-3-3/pom.xml b/src/scala/microsoft-spark-3-3/pom.xml new file mode 100644 index 000000000..ddfbab676 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/pom.xml @@ -0,0 +1,83 @@ + + 4.0.0 + + com.microsoft.scala + microsoft-spark + ${microsoft-spark.version} + + microsoft-spark-3-3_2.12 + 2019 + + UTF-8 + 2.12.10 + 2.12 + 3.3.0 + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + junit + junit + 4.13.1 + test + + + org.specs + specs + 1.2.5 + test + + + + + src/main/scala + src/test/scala + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + + + compile + testCompile + + + + + ${scala.version} + + -target:jvm-1.8 + -deprecation + -feature + + + + + + diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala new file mode 100644 index 000000000..aea355dfa --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataOutputStream + +import org.apache.spark.internal.Logging + +import scala.collection.mutable.Queue + +/** + * CallbackClient is used to communicate with the Dotnet CallbackServer. + * The client manages and maintains a pool of open CallbackConnections. + * Any callback request is delegated to a new CallbackConnection or + * unused CallbackConnection. 
+ * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() + + private[this] var isShutdown: Boolean = false + + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = + getOrCreateConnection() match { + case Some(connection) => + try { + connection.send(callbackId, writeBody) + addConnection(connection) + } catch { + case e: Exception => + logError(s"Error calling callback [callback id = $callbackId].", e) + connection.close() + throw e + } + case None => throw new Exception("Unable to get or create connection.") + } + + private def getOrCreateConnection(): Option[CallbackConnection] = synchronized { + if (isShutdown) { + logInfo("Cannot get or create connection while client is shutdown.") + return None + } + + if (connectionPool.nonEmpty) { + return Some(connectionPool.dequeue()) + } + + Some(new CallbackConnection(serDe, address, port)) + } + + private def addConnection(connection: CallbackConnection): Unit = synchronized { + assert(connection != null) + connectionPool.enqueue(connection) + } + + def shutdown(): Unit = synchronized { + if (isShutdown) { + logInfo("Shutdown called, but already shutdown.") + return + } + + logInfo("Shutting down.") + connectionPool.foreach(_.close) + connectionPool.clear + isShutdown = true + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala new file mode 100644 index 000000000..604cf029b --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.internal.Logging + +/** + * CallbackConnection is used to process the callback communication + * between the JVM and Dotnet. It uses a TCP socket to communicate with + * the Dotnet CallbackServer and the socket is expected to be reused. 
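 * (Illustrative sketch, not part of this patch.) A CallbackConnection is normally driven
 * through CallbackClient above, which dequeues an idle connection or opens a new one,
 * sends the payload, and then returns the connection to the pool for reuse, e.g.:
 * {{{
 *   // `client` and `callbackId` are assumed to be supplied by the backend / .NET side.
 *   def notifyDotnet(client: CallbackClient, callbackId: Int, message: String): Unit =
 *     client.send(callbackId, (dos, serDe) => serDe.writeString(dos, message))
 * }}}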
+ * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val socket: Socket = new Socket(address, port) + private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) + private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) + + def send( + callbackId: Int, + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { + logInfo(s"Calling callback [callback id = $callbackId] ...") + + try { + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) + + val byteArrayOutputStream = new ByteArrayOutputStream() + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) + byteArrayOutputStream.writeTo(outputStream); + } catch { + case e: Exception => { + throw new Exception("Error writing to stream.", e) + } + } + + logInfo(s"Signaling END_OF_STREAM.") + try { + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + outputStream.flush() + + val endOfStreamResponse = readFlag(inputStream) + endOfStreamResponse match { + case CallbackFlags.END_OF_STREAM => + logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") + case _ => { + throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " + + s"Received: $endOfStreamResponse") + } + } + } catch { + case e: Exception => { + throw new Exception("Error while verifying end of stream.", e) + } + } + } + + def close(): Unit = { + try { + serDe.writeInt(outputStream, CallbackFlags.CLOSE) + outputStream.flush() + } catch { + case e: Exception => logInfo("Unable to send close to .NET callback server.", e) + } + + close(socket) + close(outputStream) + close(inputStream) + } + + private def close(s: Socket): Unit = { + try { + assert(s != null) + s.close() + } catch { + case e: Exception => logInfo("Unable to close socket.", e) + } + } + + private def close(c: Closeable): Unit = { + try { + assert(c != null) + c.close() + } catch { + case e: Exception => logInfo("Unable to close closeable.", e) + } + } + + private def readFlag(inputStream: DataInputStream): Int = { + val callbackFlag = serDe.readInt(inputStream) + if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { + val exceptionMessage = serDe.readString(inputStream) + throw new DotnetException(exceptionMessage) + } + callbackFlag + } + + private object CallbackFlags { + val CLOSE: Int = -1 + val CALLBACK: Int = -2 + val DOTNET_EXCEPTION_THROWN: Int = -3 + val END_OF_STREAM: Int = -4 + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala new file mode 100644 index 000000000..c6f528aee --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import java.net.InetSocketAddress +import java.util.concurrent.TimeUnit +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS +import org.apache.spark.{SparkConf, SparkEnv} + +/** + * Netty server that invokes JVM calls based upon receiving messages from .NET. + * The implementation mirrors the RBackend. + * + */ +class DotnetBackend extends Logging { + self => // for accessing the this reference in inner class(ChannelInitializer) + private[this] var channelFuture: ChannelFuture = _ + private[this] var bootstrap: ServerBootstrap = _ + private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None + + def init(portNumber: Int): Int = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) + logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") + bossGroup = new NioEventLoopGroup(numBackendThreads) + val workerGroup = bossGroup + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast( + "frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. 
strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort + } + + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully() + } + bootstrap = null + + objectTracker.clear() + + // Send close to .NET callback server. + shutdownCallbackClient() + + // Shutdown the thread pool whose executors could still be running. + ThreadPool.shutdown() + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala new file mode 100644 index 000000000..fc3a5c911 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + +/** + * Handler for DotnetBackend. + * This implementation is similar to RBackendHandler. 
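 * For reference (illustrative sketch, not part of this patch), each request arriving here has
 * already been de-framed by the LengthFieldBasedFrameDecoder configured above: the client
 * writes a 4-byte big-endian length prefix followed by the payload, and the prefix is
 * stripped before channelRead0 sees the bytes, e.g.:
 * {{{
 *   def frame(payload: Array[Byte]): Array[Byte] = {
 *     val buffer = java.nio.ByteBuffer.allocate(4 + payload.length)
 *     buffer.putInt(payload.length) // lengthFieldLength = 4, big-endian
 *     buffer.put(payload)
 *     buffer.array()
 *   }
 * }}}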
+ */ +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) + extends SimpleChannelInboundHandler[Array[Byte]] + with Logging { + + private[this] val serDe = new SerDe(objectsTracker) + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val reply = handleBackendRequest(msg) + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = serDe.readBoolean(dis) + val processId = serDe.readInt(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) + + if (objId == "DotnetHandler") { + methodName match { + case "stopBackend" => + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + server.close() + case "rm" => + try { + val t = serDe.readObjectType(dis) + assert(t == 'c') + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + serDe.writeInt(dos, -1) + } + case "rmThread" => + try { + assert(serDe.readObjectType(dis) == 'i') + val processId = serDe.readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) + val result = ThreadPool.tryDeleteThread(processId, threadToDelete) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + serDe.writeInt(dos, -1) + } + case "connectCallback" => + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends reference of CallbackClient to dotnet side, + // so that dotnet process can send the client back to Java side + // when calling any API containing callback functions. + serDe.writeObject(dos, server.callbackClient) + case "closeCallback" => + logInfo("Requesting to close callback client") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + case _ => dos.writeInt(-1) + } + } else { + ThreadPool + .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + } + + bos.toByteArray + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Skip logging the exception message if the connection was disconnected from + // the .NET side so that .NET side doesn't have to explicitly close the connection via + // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, + // so skipping this kind of exception message does not affect the debugging. + if (!cause.getMessage.contains( + "An existing connection was forcibly closed by the remote host")) { + logError("Exception caught: ", cause) + } + + // Close the connection when an exception is raised. 
+ ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + var args: Array[java.lang.Object] = null + var methods: Array[java.lang.reflect.Method] = null + + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + objectsTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + args = readArgs(numArgs, dis) + methods = cls.getMethods + + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) + + if (index.isEmpty) { + logWarning( + s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + + val ret = selectedMethods(index.get).invoke(obj, args: _*) + + // Write status bit + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args: _*) + + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException( + "invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Throwable => + val jvmObj = objectsTracker.get(objId) + val jvmObjName = jvmObj match { + case Some(jObj) => jObj.getClass.getName + case None => "NullObject" + } + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") + + logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + + if (methods != null) { + logDebug(s"All methods for $jvmObjName:") + methods.foreach(m => logDebug(m.toString)) + } + + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + serDe.readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. 
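  // Illustrative sketch (not part of this patch): the request layout consumed by
  // handleBackendRequest above is isStatic, processId, threadId, objId, methodName,
  // numArgs, then `numArgs` typed arguments written via SerDe. The helper name and
  // arguments below are hypothetical.
  def encodeCall(
      serDe: SerDe,
      isStatic: Boolean,
      processId: Int,
      threadId: Int,
      objId: String,
      methodName: String,
      args: Seq[Object]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    dos.writeBoolean(isStatic)
    serDe.writeInt(dos, processId)
    serDe.writeInt(dos, threadId)
    serDe.writeString(dos, objId)
    serDe.writeString(dos, methodName)
    serDe.writeInt(dos, args.length)
    args.foreach(arg => serDe.writeObject(dos, arg))
    bos.toByteArray
  }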
+ def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 until numArgs) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + + if (!parameterWrapperType.isInstance(args(i))) { + // non primitive types + if (!parameterType.isPrimitive && args(i) != null) { + return false + } + + // primitive types + if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { + return false + } + } + } + + true + } + + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- parameterTypesOfMethods.indices) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. 
+ + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + for (i <- 0 until numArgs) { + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) + } + } + } + None + } + + def logError(id: String, e: Exception): Unit = {} +} + + diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala new file mode 100644 index 000000000..c70d16b03 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala @@ -0,0 +1,13 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +class DotnetException(message: String, cause: Throwable) + extends Exception(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala new file mode 100644 index 000000000..f5277c215 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD + +object DotnetRDD { + def createPythonRDD( + parent: RDD[_], + func: PythonFunction, + preservePartitoning: Boolean): PythonRDD = { + new PythonRDD(parent, func, preservePartitoning) + } + + def createJavaRDDFromArray( + sc: SparkContext, + arr: Array[Array[Byte]], + numSlices: Int): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(sc.parallelize(arr, numSlices)) + } + + def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd) +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala new file mode 100644 index 000000000..9f556338b --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import scala.collection.JavaConverters._ + +/** DotnetUtils object that hosts some helper functions + * help data type conversions between dotnet and scala + */ +object DotnetUtils { + + /** A helper function to convert scala Map to java.util.Map + * @param value - scala Map + * @return java.util.Map + */ + def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava + + /** Convert java data type to corresponding scala type + * @param value - java.lang.Object + * @return Any + */ + def mapScalaToJava(value: java.lang.Object): Any = { + value match { + case i: java.lang.Integer => i.toInt + case d: java.lang.Double => d.toDouble + case f: java.lang.Float => f.toFloat + case b: java.lang.Boolean => b.booleanValue() + case l: java.lang.Long => l.toLong + case s: java.lang.Short => s.toShort + case by: java.lang.Byte => by.toByte + case c: java.lang.Character => c.toChar + case _ => value + } + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..81cfaf88b --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala new file mode 100644 index 000000000..06a476f67 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.SparkConf + +/* + * Utils for JvmBridge. 
+ */ +object JvmBridgeUtils { + def getKeyValuePairAsString(kvp: (String, String)): String = { + return kvp._1 + "=" + kvp._2 + } + + def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = { + val sb = new StringBuilder + + for (kvp <- kvpArray) { + sb.append(getKeyValuePairAsString(kvp)) + sb.append(";") + } + + sb.toString + } + + def getSparkConfAsString(sparkConf: SparkConf): String = { + getKeyValuePairArrayAsString(sparkConf.getAll) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala new file mode 100644 index 000000000..a3df3788a --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{DataInputStream, DataOutputStream} +import java.nio.charset.StandardCharsets +import java.sql.{Date, Time, Timestamp} + +import org.apache.spark.sql.Row + +import scala.collection.JavaConverters._ + +/** + * Class responsible for serialization and deserialization between CLR & JVM. + * This implementation of methods is mostly identical to the SerDe implementation in R. + */ +class SerDe(val tracker: JVMObjectTracker) { + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'g' => new java.lang.Long(readLong(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => tracker.getObject(readString(dis)) + case 'R' => readRowArr(dis) + case 'O' => readObjectArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + private def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + private def readLong(in: DataInputStream): Long = { + in.readLong() + } + + private def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + private def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + val str = new String(bytes, "UTF-8") + str + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + readStringBytes(in, len) + } + + def readBoolean(in: DataInputStream): Boolean = { + in.readBoolean() + } + + private def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + private def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + + private def readRow(in: DataInputStream): Row = { + val len = readInt(in) + 
Row.fromSeq((0 until len).map(_ => readObject(in))) + } + + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + private def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + private def readLongArr(in: DataInputStream): Array[Long] = { + val len = readInt(in) + (0 until len).map(_ => readLong(in)).toArray + } + + private def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + val len = readInt(in) + (0 until len).map(_ => readDoubleArr(in)).toArray + } + + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + private def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + private def readRowArr(in: DataInputStream): java.util.List[Row] = { + val len = readInt(in) + (0 until len).map(_ => readRow(in)).toList.asJava + } + + private def readObjectArr(in: DataInputStream): Seq[Any] = { + val len = readInt(in) + (0 until len).map(_ => readObject(in)) + } + + private def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'g' => readLongArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'A' => readDoubleArrArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) + keys.zip(values).toMap.asJava + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Using the same mapping as SparkR implementation for now + // Methods to write out data from Java to .NET. 
+ // + // Type mapping from Java to .NET: + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Float -> double + // Double -> double + // Long -> long + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "doublearray" => dos.writeByte('A') + case "long" => dos.writeByte('g') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null || value == Unit) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "long" | "java.lang.Long" => + writeType(dos, "long") + writeLong(dos, value.asInstanceOf[Long]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeLongArr(dos, value.asInstanceOf[Array[Long]]) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[[D" => + writeType(dos, "list") + writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeLong(out: DataOutputStream, value: Long): Unit = { + out.writeLong(value) + } + + private def writeDouble(out: DataOutputStream, value: Double): Unit = { + 
out.writeDouble(value) + } + + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + out.writeBoolean(value) + } + + private def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + private def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } + + def writeString(out: DataOutputStream, value: String): Unit = { + val utf8 = value.getBytes(StandardCharsets.UTF_8) + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) + } + + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = tracker.put(value) + writeString(out, objId) + } + + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + writeType(out, "long") + out.writeInt(value.length) + value.foreach(v => out.writeLong(v)) + } + + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + writeType(out, "doublearray") + out.writeInt(value.length) + value.foreach(v => writeDoubleArr(out, v)) + } + + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..50551a7d9 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from (processId, threadId) to corresponding executor. + */ + private val executors: mutable.HashMap[(Int, Int), ExecutorService] = + new mutable.HashMap[(Int, Int), ExecutorService]() + + /** + * Run some code on a particular thread. 
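   * (Illustrative sketch, not part of this patch.) Each (processId, threadId) pair is bound
   * to its own single-threaded executor, so successive calls issued from the same .NET
   * thread always run on the same JVM thread, e.g.:
   * {{{
   *   // hypothetical ids; in practice they are supplied by the .NET side
   *   ThreadPool.run(42, 1, () => println("runs on the JVM thread bound to (42, 1)"))
   * }}}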
+ * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(processId: Int, threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(processId, threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized { + executors.remove((processId, threadId)) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Shutdown any running ExecutorServices. + */ + def shutdown(): Unit = synchronized { + executors.foreach(_._2.shutdown()) + executors.clear() + } + + /** + * Get the executor if it exists, otherwise create a new one. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return The new or existing executor with the given id. + */ + private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized { + executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala new file mode 100644 index 000000000..4551a70bd --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.deploy.dotnet + +import org.apache.spark.SparkException + +/** + * This exception type describes an exception thrown by a .NET user application. + * + * @param exitCode Exit code returned by the .NET application. + * @param dotNetStackTrace Stacktrace extracted from .NET application logs. + */ +private[spark] class DotNetUserAppException(exitCode: Int, dotNetStackTrace: Option[String]) + extends SparkException( + dotNetStackTrace match { + case None => s"User application exited with $exitCode" + case Some(e) => s"User application exited with $exitCode and .NET exception: $e" + }) diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala new file mode 100644 index 000000000..894f21c21 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
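 *
 * (Illustrative sketch, not part of this patch.) DotnetRunner below is typically invoked
 * through spark-submit with the .NET driver executable (or a zip of the published app) as
 * its first argument; paths and names here are placeholders:
 *
 *   spark-submit --class org.apache.spark.deploy.dotnet.DotnetRunner \
 *     --master local microsoft-spark-3-3_2.12-<version>.jar \
 *     ./MySparkApp <app args>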
+ */ + +package org.apache.spark.deploy.dotnet + +import java.io.File +import java.net.URI +import java.nio.file.attribute.PosixFilePermissions +import java.nio.file.{FileSystems, Files, Paths} +import java.util.Locale +import java.util.concurrent.{Semaphore, TimeUnit} + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.output.TeeOutputStream +import org.apache.hadoop.fs.Path +import org.apache.spark +import org.apache.spark.api.dotnet.DotnetBackend +import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.{ + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK, + ERROR_BUFFER_SIZE, ERROR_REDIRECITON_ENABLED +} +import org.apache.spark.util.dotnet.{Utils => DotnetUtils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException} + +import scala.collection.JavaConverters._ +import scala.io.StdIn +import scala.util.Try + +/** + * DotnetRunner class used to launch Spark .NET applications using spark-submit. + * It executes .NET application as a subprocess and then has it connect back to + * the JVM to access system properties etc. + */ +object DotnetRunner extends Logging { + private val DEBUG_PORT = 5567 + private val supportedSparkMajorMinorVersionPrefix = "3.3" + private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1", "3.2.2", "3.2.3", "3.3.0") + + val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION) + + def main(args: Array[String]): Unit = { + if (args.length == 0) { + throw new IllegalArgumentException("At least one argument is expected.") + } + + DotnetUtils.validateSparkVersions( + sys.props + .getOrElse( + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key, + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString) + .toBoolean, + spark.SPARK_VERSION, + SPARK_VERSION, + supportedSparkMajorMinorVersionPrefix, + supportedSparkVersions) + + val settings = initializeSettings(args) + + // Determines if this needs to be run in debug mode. + // In debug mode this runner will not launch a .NET process. + val runInDebugMode = settings._1 + @volatile var dotnetBackendPortNumber = settings._2 + var dotnetExecutable = "" + var otherArgs: Array[String] = null + + if (!runInDebugMode) { + if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) { + var zipFileName = args(0) + val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI) + val workingDir = new File("").getAbsoluteFile + val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath())) + + // Standalone cluster mode where .NET application is remotely located. + if (zipFileUri.getScheme() != "file") { + zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName + } + + logInfo(s"Unzipping .NET driver $zipFileName to $driverDir") + DotnetUtils.unzip(new File(zipFileName), driverDir) + + // Reuse windows-specific formatting in PythonRunner. + dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1))) + otherArgs = args.slice(2, args.length) + } else { + // Reuse windows-specific formatting in PythonRunner. 
+ dotnetExecutable = PythonRunner.formatPath(args(0)) + otherArgs = args.slice(1, args.length) + } + } else { + otherArgs = args.slice(1, args.length) + } + + val processParameters = new java.util.ArrayList[String] + processParameters.add(dotnetExecutable) + otherArgs.foreach(arg => processParameters.add(arg)) + + logInfo(s"Starting DotnetBackend with $dotnetExecutable.") + + // Time to wait for DotnetBackend to initialize in seconds. + val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt + + // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our + // Java system properties etc. + val dotnetBackend = new DotnetBackend() + val initialized = new Semaphore(0) + val dotnetBackendThread = new Thread("DotnetBackend") { + override def run() { + // need to get back dotnetBackendPortNumber because if the value passed to init is 0 + // the port number is dynamically assigned in the backend + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) + logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + initialized.release() + dotnetBackend.run() + } + } + + dotnetBackendThread.start() + + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + if (!runInDebugMode) { + var returnCode = -1 + var process: Process = null + val enableLogRedirection: Boolean = sys.props + .getOrElse( + ERROR_REDIRECITON_ENABLED.key, + ERROR_REDIRECITON_ENABLED.defaultValue.get.toString).toBoolean + val stderrBuffer: Option[CircularBuffer] = Option(enableLogRedirection).collect { + case true => new CircularBuffer( + sys.props.getOrElse( + ERROR_BUFFER_SIZE.key, + ERROR_BUFFER_SIZE.defaultValue.get.toString).toInt) + } + + try { + val builder = new ProcessBuilder(processParameters) + val env = builder.environment() + env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) + + for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + env.put(key, value) + logInfo(s"Adding key=$key and value=$value to environment") + } + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + process = builder.start() + + // Redirect stdin of JVM process to stdin of .NET process. + new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start() + // Redirect stdout and stderr of .NET process to System.out and to buffer + // if log direction is enabled. If not, redirect only to System.out. + new RedirectThread( + process.getInputStream, + stderrBuffer match { + case Some(buffer) => new TeeOutputStream(System.out, buffer) + case _ => System.out + }, + "redirect .NET stdout and stderr").start() + + process.waitFor() + } catch { + case t: Throwable => + logThrowable(t) + } finally { + returnCode = closeDotnetProcess(process) + closeBackend(dotnetBackend) + } + if (returnCode != 0) { + if (stderrBuffer.isDefined) { + throw new DotNetUserAppException(returnCode, Some(stderrBuffer.get.toString)) + } else { + throw new SparkUserAppException(returnCode) + } + } else { + logInfo(s".NET application exited successfully") + } + // TODO: The following is causing the following error: + // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16, + // (reason: Shutdown hook called before final status was reported.) + // DotnetUtils.exit(returnCode) + } else { + // scalastyle:off println + println("***********************************************************************") + println("* .NET Backend running debug mode. 
Press enter to exit *") + println("***********************************************************************") + // scalastyle:on println + + StdIn.readLine() + closeBackend(dotnetBackend) + DotnetUtils.exit(0) + } + } else { + logError(s"DotnetBackend did not initialize in $backendTimeout seconds") + DotnetUtils.exit(-1) + } + } + + // When the executable is downloaded as part of zip file, check if the file exists + // after zip file is unzipped under the given dir. Once it is found, change the + // permission to executable (only for Unix systems, since the zip file may have been + // created under Windows. Finally, the absolute path for the executable is returned. + private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = { + val path = Paths.get(dir.getAbsolutePath, dotnetExecutable) + val resolvedExecutable = if (Files.isRegularFile(path)) { + path.toAbsolutePath.toString + } else { + Files + .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath)) + .iterator() + .asScala + .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match { + case Some(path) => path.toAbsolutePath.toString + case None => + throw new IllegalArgumentException( + s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}") + } + } + + if (DotnetUtils.supportPosix) { + Files.setPosixFilePermissions( + Paths.get(resolvedExecutable), + PosixFilePermissions.fromString("rwxr-xr-x")) + } + + resolvedExecutable + } + + /** + * Download HDFS file into the supplied directory and return its local path. + * Will throw an exception if there are errors during downloading. + */ + private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = { + val sparkConf = new SparkConf() + val filePath = new Path(hdfsFilePath) + + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val jarFileName = filePath.getName + val localFile = new File(driverDir, jarFileName) + + if (!localFile.exists()) { // May already exist if running multiple workers on one node + logInfo(s"Copying user file $filePath to $driverDir") + Utils.fetchFile( + hdfsFilePath, + new File(driverDir), + sparkConf, + hadoopConf, + System.currentTimeMillis(), + useCache = false) + } + + if (!localFile.exists()) { + throw new Exception(s"Did not see expected $jarFileName in $driverDir") + } + + localFile + } + + private def closeBackend(dotnetBackend: DotnetBackend): Unit = { + logInfo("Closing DotnetBackend") + dotnetBackend.close() + } + + private def closeDotnetProcess(dotnetProcess: Process): Int = { + if (dotnetProcess == null) { + return -1 + } else if (!dotnetProcess.isAlive) { + return dotnetProcess.exitValue() + } + + // Try to (gracefully on Linux) kill the process and resort to force if interrupted + var returnCode = -1 + logInfo("Closing .NET process") + try { + dotnetProcess.destroy() + returnCode = dotnetProcess.waitFor() + } catch { + case _: InterruptedException => + logInfo( + "Thread interrupted while waiting for graceful close. 
Forcefully closing .NET process") + returnCode = dotnetProcess.destroyForcibly().waitFor() + case t: Throwable => + logThrowable(t) + } + + returnCode + } + + private def initializeSettings(args: Array[String]): (Boolean, Int) = { + val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( + "debug") + var portNumber = 0 + if (runInDebugMode) { + if (args.length == 1) { + portNumber = DEBUG_PORT + } else if (args.length == 2) { + portNumber = Integer.parseInt(args(1)) + } + } + + (runInDebugMode, portNumber) + } + + private def logThrowable(throwable: Throwable): Unit = + logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}") +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala new file mode 100644 index 000000000..18ba4c6e5 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.internal.config.dotnet + +import org.apache.spark.internal.config.ConfigBuilder + +private[spark] object Dotnet { + val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf + .createWithDefault(10) + + val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK = + ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf + .createWithDefault(false) + + val ERROR_REDIRECITON_ENABLED = + ConfigBuilder("spark.nonjvm.error.forwarding.enabled").booleanConf + .createWithDefault(false) + + val ERROR_BUFFER_SIZE = + ConfigBuilder("spark.nonjvm.error.buffer.size") + .intConf + .checkValue(_ >= 0, "The error buffer size must not be negative") + .createWithDefault(10240) +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala new file mode 100644 index 000000000..3e3c3e0e3 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala @@ -0,0 +1,26 @@ + +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
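 *
 * (Illustrative sketch, not part of this patch.) The configuration keys introduced in
 * Dotnet.scala above are ordinary Spark settings; the values here are arbitrary examples:
 *
 *   new SparkConf()
 *     .set("spark.dotnet.numDotnetBackendThreads", "16")
 *     .set("spark.nonjvm.error.forwarding.enabled", "true")
 *     .set("spark.nonjvm.error.buffer.size", "20480")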
+ */ + +package org.apache.spark.mllib.api.dotnet + +import org.apache.spark.ml._ +import scala.collection.JavaConverters._ + +/** MLUtils object that hosts helper functions + * related to ML usage + */ +object MLUtils { + + /** A helper function to let pipeline accept java.util.ArrayList + * format stages in scala code + * @param pipeline - The pipeline to be set stages + * @param value - A java.util.ArrayList of PipelineStages to be set as stages + * @return The pipeline + */ + def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline = + pipeline.setStages(value.asScala.toArray) +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala new file mode 100644 index 000000000..5d06d4304 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.api.dotnet.CallbackClient +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.streaming.DataStreamWriter + +class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging { + def call(batchDF: DataFrame, batchId: Long): Unit = + callbackClient.send( + callbackId, + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) + }) +} + +object DotnetForeachBatchHelper { + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") + } + + dsw.foreachBatch(dotnetForeachFunc.call _) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala new file mode 100644 index 000000000..b5e97289a --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction} +import org.apache.spark.broadcast.Broadcast + +object SQLUtils { + + /** + * Exposes createPythonFunction to the .NET client to enable registering UDFs. 
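+ * The arguments are packaged, unmodified, into a
+ * [[org.apache.spark.api.python.PythonFunction]]; no validation happens here.
+ * An illustrative call (all argument values below are placeholders, normally
+ * supplied by the .NET client):
+ * {{{
+ *   val udf = SQLUtils.createPythonFunction(
+ *     Array[Byte](),                                         // serialized command
+ *     new java.util.HashMap[String, String](),               // envVars
+ *     new java.util.ArrayList[String](),                     // pythonIncludes
+ *     "dotnet-worker",                                       // pythonExec
+ *     "",                                                    // pythonVersion
+ *     new java.util.ArrayList[Broadcast[PythonBroadcast]](), // broadcastVars
+ *     null)                                                  // accumulator
+ * }}}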
+ */ + def createPythonFunction( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVersion: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: PythonAccumulatorV2): PythonFunction = { + + PythonFunction( + command, + envVars, + pythonIncludes, + pythonExec, + pythonVersion, + broadcastVars, + accumulator) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala new file mode 100644 index 000000000..1cd45aa95 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/sql/test/TestUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.MemoryStream + +object TestUtils { + + /** + * Helper method to create typed MemoryStreams intended for use in unit tests. + * @param sqlContext The SQLContext. + * @param streamType The type of memory stream to create. This string is the `Name` + * property of the dotnet type. + * @return A typed MemoryStream. + */ + def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = { + import sqlContext.implicits._ + + streamType match { + case "Int32" => MemoryStream[Int] + case "String" => MemoryStream[String] + case _ => throw new Exception(s"$streamType not supported") + } + } +} diff --git a/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala new file mode 100644 index 000000000..f9400789a --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/main/scala/org/apache/spark/util/dotnet/Utils.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.util.dotnet + +import java.io._ +import java.nio.file.attribute.PosixFilePermission +import java.nio.file.attribute.PosixFilePermission._ +import java.nio.file.{FileSystems, Files} +import java.util.{Timer, TimerTask} + +import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile} +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK + +import scala.collection.JavaConverters._ +import scala.collection.Set + +/** + * Utility methods. 
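+ * Most helpers here package or unpack the .NET worker: zipping and unzipping
+ * artifacts while preserving Unix permissions, and validating Spark versions.
+ * An illustrative round trip (paths are hypothetical):
+ * {{{
+ *   import java.io.File
+ *   Utils.zip(new File("/tmp/publish"), new File("/tmp/worker.zip"))
+ *   Utils.unzip(new File("/tmp/worker.zip"), new File("/tmp/extracted"))
+ * }}}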
+ */ +object Utils extends Logging { + private val posixFilePermissions = Array( + OWNER_READ, + OWNER_WRITE, + OWNER_EXECUTE, + GROUP_READ, + GROUP_WRITE, + GROUP_EXECUTE, + OTHERS_READ, + OTHERS_WRITE, + OTHERS_EXECUTE) + + val supportPosix: Boolean = + FileSystems.getDefault.supportedFileAttributeViews().contains("posix") + + /** + * Compress all files under given directory into one zip file and drop it to the target directory + * + * @param sourceDir source directory to zip + * @param targetZipFile target zip file + */ + def zip(sourceDir: File, targetZipFile: File): Unit = { + var fos: FileOutputStream = null + var zos: ZipArchiveOutputStream = null + try { + fos = new FileOutputStream(targetZipFile) + zos = new ZipArchiveOutputStream(fos) + + val sourcePath = sourceDir.toPath + FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file => + var in: FileInputStream = null + try { + val path = file.toPath + val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString) + if (supportPosix) { + entry.setUnixMode( + permissionsToMode(Files.getPosixFilePermissions(path).asScala) + | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4)) + } else if (entry.getName.endsWith(".exe")) { + entry.setUnixMode(0x1ED) // 755 + } else { + entry.setUnixMode(0x1A4) // 644 + } + zos.putArchiveEntry(entry) + + in = new FileInputStream(file) + IOUtils.copy(in, zos) + zos.closeArchiveEntry() + } finally { + IOUtils.closeQuietly(in) + } + } + } finally { + IOUtils.closeQuietly(zos) + IOUtils.closeQuietly(fos) + } + } + + /** + * Unzip a file to the given directory + * + * @param file file to be unzipped + * @param targetDir target directory + */ + def unzip(file: File, targetDir: File): Unit = { + var zipFile: ZipFile = null + try { + targetDir.mkdirs() + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.foreach { entry => + val targetFile = new File(targetDir, entry.getName) + + if (targetFile.exists()) { + logWarning( + s"Target file/directory $targetFile already exists. Skip it for now. " + + s"Make sure this is expected.") + } else { + if (entry.isDirectory) { + targetFile.mkdirs() + } else { + targetFile.getParentFile.mkdirs() + val input = zipFile.getInputStream(entry) + val output = new FileOutputStream(targetFile) + IOUtils.copy(input, output) + IOUtils.closeQuietly(input) + IOUtils.closeQuietly(output) + if (supportPosix) { + val permissions = modeToPermissions(entry.getUnixMode) + // When run in Unix system, permissions will be empty, thus skip + // setting the empty permissions (which will empty the previous permissions). + if (permissions.nonEmpty) { + Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava) + } + } + } + } + } + } catch { + case e: Exception => logError("exception caught during decompression:" + e) + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + /** + * Exits the JVM, trying to do it nicely, otherwise doing it nastily. 
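+ * A watchdog timer is armed so that, if the graceful System.exit call does not
+ * terminate the JVM within maxDelayMillis, Runtime.getRuntime.halt(status) is
+ * used to force termination.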
+ * + * @param status the exit status, zero for OK, non-zero for error + * @param maxDelayMillis the maximum delay in milliseconds + */ + def exit(status: Int, maxDelayMillis: Long) { + try { + logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis") + + // setup a timer, so if nice exit fails, the nasty exit happens + val timer = new Timer() + timer.schedule(new TimerTask() { + + override def run() { + Runtime.getRuntime.halt(status) + } + }, maxDelayMillis) + // try to exit nicely + System.exit(status); + } catch { + // exit nastily if we have a problem + case _: Throwable => Runtime.getRuntime.halt(status) + } finally { + // should never get here + Runtime.getRuntime.halt(status) + } + } + + /** + * Exits the JVM, trying to do it nicely, wait 1 second + * + * @param status the exit status, zero for OK, non-zero for error + */ + def exit(status: Int): Unit = { + exit(status, 1000) + } + + /** + * Normalize the Spark version by taking the first three numbers. + * For example: + * x.y.z => x.y.z + * x.y.z.xxx.yyy => x.y.z + * x.y => x.y + * + * @param version the Spark version to normalize + * @return Normalized Spark version. + */ + def normalizeSparkVersion(version: String): String = { + version + .split('.') + .take(3) + .zipWithIndex + .map({ + case (element, index) => { + index match { + case 2 => element.split("\\D+").lift(0).getOrElse("") + case _ => element + } + } + }) + .mkString(".") + } + + /** + * Validates the normalized spark version by verifying: + * - Spark version starts with sparkMajorMinorVersionPrefix. + * - If ignoreSparkPatchVersion is + * - true: valid + * - false: check if the spark version is in supportedSparkVersions. + * @param ignoreSparkPatchVersion Ignore spark patch version. + * @param sparkVersion The spark version. + * @param normalizedSparkVersion: The normalized spark version. + * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against. + * @param supportedSparkVersions The set of supported spark versions. + */ + def validateSparkVersions( + ignoreSparkPatchVersion: Boolean, + sparkVersion: String, + normalizedSparkVersion: String, + supportedSparkMajorMinorVersionPrefix: String, + supportedSparkVersions: Set[String]): Unit = { + if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) { + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (ignoreSparkPatchVersion) { + logWarning( + s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (!supportedSparkVersions(normalizedSparkVersion)) { + val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ") + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. 
" + + s"Supported versions: '$supportedVersions'.") + } + } + + private[spark] def listZipFileEntries(file: File): Array[String] = { + var zipFile: ZipFile = null + try { + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.map(_.getName).toArray + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = { + posixFilePermissions.foldLeft(0) { (mode, perm) => + (mode << 1) | (if (permissions.contains(perm)) 1 else 0) + } + } + + private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = { + posixFilePermissions.zipWithIndex + .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 } + .map(_._1) + .toSet + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..7088537e1 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import Extensions._ +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + serDe.writeInt(m, 1) // processId + serDe.writeInt(m, 1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readNBytes(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..445486bbd 
--- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var backend: DotnetBackend = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // Specifying port = 0 to select port dynamically. + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(backend.callbackClient.isDefined) + assertThrows(classOf[Exception], () => { + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + }) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..c6904403b --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +private[dotnet] object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readNBytes(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..41401d680 --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,373 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters._ + +@Test +class SerDeTest { + private var serDe: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + serDe = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, serDe.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, serDe.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, serDe.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, serDe.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, serDe.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, serDe.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + 
in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + mapAsJavaMap(Map( + 11 -> true, + 22 -> 42.42, + 33 -> null)), + serDe.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(mapAsJavaMap(Map()), serDe.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, serDe.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows(classOf[NoSuchElementException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2st row + in.writeByte('b') // type of 1st element in 2nd row + 
in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3nd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + seqAsJavaList(Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99)))), + serDe.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + serDe.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + serDe.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + serDe.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + serDe.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifiers + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new 
ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream (in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} diff --git a/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala new file mode 100644 index 000000000..b8e0122fd --- /dev/null +++ b/src/scala/microsoft-spark-3-3/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.util.dotnet + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK +import org.junit.Assert.{assertEquals, assertThrows} +import org.junit.Test + +@Test +class UtilsTest { + + @Test + def shouldIgnorePatchVersion(): Unit = { + val sparkVersion = "3.3.0" + val sparkMajorMinorVersionPrefix = "3.3" + val supportedSparkVersions = Set[String]("3.3.0") + + Utils.validateSparkVersions( + true, + sparkVersion, + Utils.normalizeSparkVersion(sparkVersion), + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + } + + @Test + def shouldThrowForUnsupportedVersion(): Unit = { + val sparkVersion = "3.3.0" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.3" + val supportedSparkVersions = Set[String]("3.3.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.", + exception.getMessage) + } + + @Test + def shouldThrowForUnsupportedMajorMinorVersion(): Unit = { + val sparkVersion = "2.4.4" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.3" + val supportedSparkVersions = Set[String]("3.3.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. 
" + + s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.", + exception.getMessage) + } +} diff --git a/src/scala/microsoft-spark-3-4/pom.xml b/src/scala/microsoft-spark-3-4/pom.xml new file mode 100644 index 000000000..319c0c84d --- /dev/null +++ b/src/scala/microsoft-spark-3-4/pom.xml @@ -0,0 +1,83 @@ + + 4.0.0 + + com.microsoft.scala + microsoft-spark + ${microsoft-spark.version} + + microsoft-spark-3-4_2.12 + 2019 + + UTF-8 + 2.12.10 + 2.12 + 3.4.0 + + + + + org.scala-lang + scala-library + ${scala.version} + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${spark.version} + provided + + + junit + junit + 4.13.1 + test + + + org.specs + specs + 1.2.5 + test + + + + + src/main/scala + src/test/scala + + + org.scala-tools + maven-scala-plugin + 2.15.2 + + + + compile + testCompile + + + + + ${scala.version} + + -target:jvm-1.8 + -deprecation + -feature + + + + + + diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala new file mode 100644 index 000000000..aea355dfa --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackClient.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.DataOutputStream + +import org.apache.spark.internal.Logging + +import scala.collection.mutable.Queue + +/** + * CallbackClient is used to communicate with the Dotnet CallbackServer. + * The client manages and maintains a pool of open CallbackConnections. + * Any callback request is delegated to a new CallbackConnection or + * unused CallbackConnection. 
+ * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackClient(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val connectionPool: Queue[CallbackConnection] = Queue[CallbackConnection]() + + private[this] var isShutdown: Boolean = false + + final def send(callbackId: Int, writeBody: (DataOutputStream, SerDe) => Unit): Unit = + getOrCreateConnection() match { + case Some(connection) => + try { + connection.send(callbackId, writeBody) + addConnection(connection) + } catch { + case e: Exception => + logError(s"Error calling callback [callback id = $callbackId].", e) + connection.close() + throw e + } + case None => throw new Exception("Unable to get or create connection.") + } + + private def getOrCreateConnection(): Option[CallbackConnection] = synchronized { + if (isShutdown) { + logInfo("Cannot get or create connection while client is shutdown.") + return None + } + + if (connectionPool.nonEmpty) { + return Some(connectionPool.dequeue()) + } + + Some(new CallbackConnection(serDe, address, port)) + } + + private def addConnection(connection: CallbackConnection): Unit = synchronized { + assert(connection != null) + connectionPool.enqueue(connection) + } + + def shutdown(): Unit = synchronized { + if (isShutdown) { + logInfo("Shutdown called, but already shutdown.") + return + } + + logInfo("Shutting down.") + connectionPool.foreach(_.close) + connectionPool.clear + isShutdown = true + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala new file mode 100644 index 000000000..604cf029b --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/CallbackConnection.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{ByteArrayOutputStream, Closeable, DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.internal.Logging + +/** + * CallbackConnection is used to process the callback communication + * between the JVM and Dotnet. It uses a TCP socket to communicate with + * the Dotnet CallbackServer and the socket is expected to be reused. 
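+ *
+ * Each send() frames a request as: the CALLBACK flag, the callback id, the body
+ * length and the body bytes, followed by an END_OF_STREAM flag; an END_OF_STREAM
+ * reply is then expected before the connection is considered reusable. close()
+ * sends a CLOSE flag before shutting the socket down.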
+ * @param address The address of the Dotnet CallbackServer + * @param port The port of the Dotnet CallbackServer + */ +class CallbackConnection(serDe: SerDe, address: String, port: Int) extends Logging { + private[this] val socket: Socket = new Socket(address, port) + private[this] val inputStream: DataInputStream = new DataInputStream(socket.getInputStream) + private[this] val outputStream: DataOutputStream = new DataOutputStream(socket.getOutputStream) + + def send( + callbackId: Int, + writeBody: (DataOutputStream, SerDe) => Unit): Unit = { + logInfo(s"Calling callback [callback id = $callbackId] ...") + + try { + serDe.writeInt(outputStream, CallbackFlags.CALLBACK) + serDe.writeInt(outputStream, callbackId) + + val byteArrayOutputStream = new ByteArrayOutputStream() + writeBody(new DataOutputStream(byteArrayOutputStream), serDe) + serDe.writeInt(outputStream, byteArrayOutputStream.size) + byteArrayOutputStream.writeTo(outputStream); + } catch { + case e: Exception => { + throw new Exception("Error writing to stream.", e) + } + } + + logInfo(s"Signaling END_OF_STREAM.") + try { + serDe.writeInt(outputStream, CallbackFlags.END_OF_STREAM) + outputStream.flush() + + val endOfStreamResponse = readFlag(inputStream) + endOfStreamResponse match { + case CallbackFlags.END_OF_STREAM => + logInfo(s"Received END_OF_STREAM signal. Calling callback [callback id = $callbackId] successful.") + case _ => { + throw new Exception(s"Error verifying end of stream. Expected: ${CallbackFlags.END_OF_STREAM}, " + + s"Received: $endOfStreamResponse") + } + } + } catch { + case e: Exception => { + throw new Exception("Error while verifying end of stream.", e) + } + } + } + + def close(): Unit = { + try { + serDe.writeInt(outputStream, CallbackFlags.CLOSE) + outputStream.flush() + } catch { + case e: Exception => logInfo("Unable to send close to .NET callback server.", e) + } + + close(socket) + close(outputStream) + close(inputStream) + } + + private def close(s: Socket): Unit = { + try { + assert(s != null) + s.close() + } catch { + case e: Exception => logInfo("Unable to close socket.", e) + } + } + + private def close(c: Closeable): Unit = { + try { + assert(c != null) + c.close() + } catch { + case e: Exception => logInfo("Unable to close closeable.", e) + } + } + + private def readFlag(inputStream: DataInputStream): Int = { + val callbackFlag = serDe.readInt(inputStream) + if (callbackFlag == CallbackFlags.DOTNET_EXCEPTION_THROWN) { + val exceptionMessage = serDe.readString(inputStream) + throw new DotnetException(exceptionMessage) + } + callbackFlag + } + + private object CallbackFlags { + val CLOSE: Int = -1 + val CALLBACK: Int = -2 + val DOTNET_EXCEPTION_THROWN: Int = -3 + val END_OF_STREAM: Int = -4 + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala new file mode 100644 index 000000000..c6f528aee --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackend.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import java.net.InetSocketAddress +import java.util.concurrent.TimeUnit +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_NUM_BACKEND_THREADS +import org.apache.spark.{SparkConf, SparkEnv} + +/** + * Netty server that invokes JVM calls based upon receiving messages from .NET. + * The implementation mirrors the RBackend. + * + */ +class DotnetBackend extends Logging { + self => // for accessing the this reference in inner class(ChannelInitializer) + private[this] var channelFuture: ChannelFuture = _ + private[this] var bootstrap: ServerBootstrap = _ + private[this] var bossGroup: EventLoopGroup = _ + private[this] val objectTracker = new JVMObjectTracker + + @volatile + private[dotnet] var callbackClient: Option[CallbackClient] = None + + def init(portNumber: Int): Int = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + val numBackendThreads = conf.get(DOTNET_NUM_BACKEND_THREADS) + logInfo(s"The number of DotnetBackend threads is set to $numBackendThreads.") + bossGroup = new NioEventLoopGroup(numBackendThreads) + val workerGroup = bossGroup + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(classOf[NioServerSocketChannel]) + + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { + def initChannel(ch: SocketChannel): Unit = { + ch.pipeline() + .addLast("encoder", new ByteArrayEncoder()) + .addLast( + "frameDecoder", + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 4 + // lengthAdjustment = 0 + // initialBytesToStrip = 4, i.e. 
strip out the length field itself + new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) + .addLast("decoder", new ByteArrayDecoder()) + .addLast("handler", new DotnetBackendHandler(self, objectTracker)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", portNumber)) + channelFuture.syncUninterruptibly() + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort + } + + private[dotnet] def setCallbackClient(address: String, port: Int): Unit = synchronized { + callbackClient = callbackClient match { + case Some(_) => throw new Exception("Callback client already set.") + case None => + logInfo(s"Connecting to a callback server at $address:$port") + Some(new CallbackClient(new SerDe(objectTracker), address, port)) + } + } + + private[dotnet] def shutdownCallbackClient(): Unit = synchronized { + callbackClient match { + case Some(client) => client.shutdown() + case None => logInfo("Callback server has already been shutdown.") + } + callbackClient = None + } + + def run(): Unit = { + channelFuture.channel.closeFuture().syncUninterruptibly() + } + + def close(): Unit = { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS) + channelFuture = null + } + if (bootstrap != null && bootstrap.config().group() != null) { + bootstrap.config().group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.config().childGroup() != null) { + bootstrap.config().childGroup().shutdownGracefully() + } + bootstrap = null + + objectTracker.clear() + + // Send close to .NET callback server. + shutdownCallbackClient() + + // Shutdown the thread pool whose executors could still be running. + ThreadPool.shutdown() + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala new file mode 100644 index 000000000..fc3a5c911 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetBackendHandler.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.collection.mutable.HashMap +import scala.language.existentials + +/** + * Handler for DotnetBackend. + * This implementation is similar to RBackendHandler. 
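+ *
+ * Each incoming message is read as: an isStatic flag, the process id, the thread id,
+ * the target object id, the method name, the number of arguments and then the typed
+ * arguments themselves. Replies begin with a status int (0 on success, -1 on failure)
+ * followed by the serialized return value, or an error string on failure.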
+ */ +class DotnetBackendHandler(server: DotnetBackend, objectsTracker: JVMObjectTracker) + extends SimpleChannelInboundHandler[Array[Byte]] + with Logging { + + private[this] val serDe = new SerDe(objectsTracker) + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + val reply = handleBackendRequest(msg) + ctx.write(reply) + } + + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { + ctx.flush() + } + + def handleBackendRequest(msg: Array[Byte]): Array[Byte] = { + val bis = new ByteArrayInputStream(msg) + val dis = new DataInputStream(bis) + + val bos = new ByteArrayOutputStream() + val dos = new DataOutputStream(bos) + + // First bit is isStatic + val isStatic = serDe.readBoolean(dis) + val processId = serDe.readInt(dis) + val threadId = serDe.readInt(dis) + val objId = serDe.readString(dis) + val methodName = serDe.readString(dis) + val numArgs = serDe.readInt(dis) + + if (objId == "DotnetHandler") { + methodName match { + case "stopBackend" => + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + server.close() + case "rm" => + try { + val t = serDe.readObjectType(dis) + assert(t == 'c') + val objToRemove = serDe.readString(dis) + objectsTracker.remove(objToRemove) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, null) + } catch { + case e: Exception => + logError(s"Removing $objId failed", e) + serDe.writeInt(dos, -1) + } + case "rmThread" => + try { + assert(serDe.readObjectType(dis) == 'i') + val processId = serDe.readInt(dis) + assert(serDe.readObjectType(dis) == 'i') + val threadToDelete = serDe.readInt(dis) + val result = ThreadPool.tryDeleteThread(processId, threadToDelete) + serDe.writeInt(dos, 0) + serDe.writeObject(dos, result.asInstanceOf[AnyRef]) + } catch { + case e: Exception => + logError(s"Removing thread $threadId failed", e) + serDe.writeInt(dos, -1) + } + case "connectCallback" => + assert(serDe.readObjectType(dis) == 'c') + val address = serDe.readString(dis) + assert(serDe.readObjectType(dis) == 'i') + val port = serDe.readInt(dis) + server.setCallbackClient(address, port) + serDe.writeInt(dos, 0) + + // Sends reference of CallbackClient to dotnet side, + // so that dotnet process can send the client back to Java side + // when calling any API containing callback functions. + serDe.writeObject(dos, server.callbackClient) + case "closeCallback" => + logInfo("Requesting to close callback client") + server.shutdownCallbackClient() + serDe.writeInt(dos, 0) + serDe.writeType(dos, "void") + case _ => dos.writeInt(-1) + } + } else { + ThreadPool + .run(processId, threadId, () => handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)) + } + + bos.toByteArray + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Skip logging the exception message if the connection was disconnected from + // the .NET side so that .NET side doesn't have to explicitly close the connection via + // "stopBackend." Note that an exception is still thrown if the exit status is non-zero, + // so skipping this kind of exception message does not affect the debugging. + if (!cause.getMessage.contains( + "An existing connection was forcibly closed by the remote host")) { + logError("Exception caught: ", cause) + } + + // Close the connection when an exception is raised. 
+ ctx.close() + } + + def handleMethodCall( + isStatic: Boolean, + objId: String, + methodName: String, + numArgs: Int, + dis: DataInputStream, + dos: DataOutputStream): Unit = { + var obj: Object = null + var args: Array[java.lang.Object] = null + var methods: Array[java.lang.reflect.Method] = null + + try { + val cls = if (isStatic) { + Utils.classForName(objId) + } else { + objectsTracker.get(objId) match { + case None => throw new IllegalArgumentException("Object not found " + objId) + case Some(o) => + obj = o + o.getClass + } + } + + args = readArgs(numArgs, dis) + methods = cls.getMethods + + val selectedMethods = methods.filter(m => m.getName == methodName) + if (selectedMethods.length > 0) { + val index = findMatchedSignature(selectedMethods.map(_.getParameterTypes), args) + + if (index.isEmpty) { + logWarning( + s"cannot find matching method ${cls}.$methodName. " + + s"Candidates are:") + selectedMethods.foreach { method => + logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})") + } + throw new Exception(s"No matched method found for $cls.$methodName") + } + + val ret = selectedMethods(index.get).invoke(obj, args: _*) + + // Write status bit + serDe.writeInt(dos, 0) + serDe.writeObject(dos, ret.asInstanceOf[AnyRef]) + } else if (methodName == "") { + // methodName should be "" for constructor + val ctor = cls.getConstructors.filter { x => + matchMethod(numArgs, args, x.getParameterTypes) + }.head + + val obj = ctor.newInstance(args: _*) + + serDe.writeInt(dos, 0) + serDe.writeObject(dos, obj.asInstanceOf[AnyRef]) + } else { + throw new IllegalArgumentException( + "invalid method " + methodName + " for object " + objId) + } + } catch { + case e: Throwable => + val jvmObj = objectsTracker.get(objId) + val jvmObjName = jvmObj match { + case Some(jObj) => jObj.getClass.getName + case None => "NullObject" + } + val argsStr = args + .map(arg => { + if (arg != null) { + s"[Type=${arg.getClass.getCanonicalName}, Value: $arg]" + } else { + "[Value: NULL]" + } + }) + .mkString(", ") + + logError(s"Failed to execute '$methodName' on '$jvmObjName' with args=($argsStr)") + + if (methods != null) { + logDebug(s"All methods for $jvmObjName:") + methods.foreach(m => logDebug(m.toString)) + } + + serDe.writeInt(dos, -1) + serDe.writeString(dos, Utils.exceptionString(e.getCause)) + } + } + + // Read a number of arguments from the data input stream + def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { + (0 until numArgs).map { arg => + serDe.readObject(dis) + }.toArray + } + + // Checks if the arguments passed in args matches the parameter types. + // NOTE: Currently we do exact match. We may add type conversions later. 
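+  // For example, a primitive parameter type is matched against its boxed form:
+  //   matchMethod(1, Array[Object](Integer.valueOf(42)), Array[Class[_]](java.lang.Integer.TYPE))  // true
+  //   matchMethod(1, Array[Object]("42"), Array[Class[_]](java.lang.Integer.TYPE))                 // false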
+ def matchMethod( + numArgs: Int, + args: Array[java.lang.Object], + parameterTypes: Array[Class[_]]): Boolean = { + if (parameterTypes.length != numArgs) { + return false + } + + for (i <- 0 until numArgs) { + val parameterType = parameterTypes(i) + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + + if (!parameterWrapperType.isInstance(args(i))) { + // non primitive types + if (!parameterType.isPrimitive && args(i) != null) { + return false + } + + // primitive types + if (parameterType.isPrimitive && !parameterWrapperType.isInstance(args(i))) { + return false + } + } + } + + true + } + + // Find a matching method signature in an array of signatures of constructors + // or methods of the same name according to the passed arguments. Arguments + // may be converted in order to match a signature. + // + // Note that in Java reflection, constructors and normal methods are of different + // classes, and share no parent class that provides methods for reflection uses. + // There is no unified way to handle them in this function. So an array of signatures + // is passed in instead of an array of candidate constructors or methods. + // + // Returns an Option[Int] which is the index of the matched signature in the array. + def findMatchedSignature( + parameterTypesOfMethods: Array[Array[Class[_]]], + args: Array[Object]): Option[Int] = { + val numArgs = args.length + + for (index <- parameterTypesOfMethods.indices) { + val parameterTypes = parameterTypesOfMethods(index) + + if (parameterTypes.length == numArgs) { + var argMatched = true + var i = 0 + while (i < numArgs && argMatched) { + val parameterType = parameterTypes(i) + + if (parameterType == classOf[Seq[Any]] && args(i).getClass.isArray) { + // The case that the parameter type is a Scala Seq and the argument + // is a Java array is considered matching. The array will be converted + // to a Seq later if this method is matched. + } else { + var parameterWrapperType = parameterType + + // Convert native parameters to Object types as args is Array[Object] here + if (parameterType.isPrimitive) { + parameterWrapperType = parameterType match { + case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Long] + case java.lang.Double.TYPE => classOf[java.lang.Double] + case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] + case _ => parameterType + } + } + if ((parameterType.isPrimitive || args(i) != null) && + !parameterWrapperType.isInstance(args(i))) { + argMatched = false + } + } + + i = i + 1 + } + + if (argMatched) { + // For now, we return the first matching method. + // TODO: find best method in matching methods. 
+ + // Convert args if needed + val parameterTypes = parameterTypesOfMethods(index) + + for (i <- 0 until numArgs) { + if (parameterTypes(i) == classOf[Seq[Any]] && args(i).getClass.isArray) { + // Convert a Java array to scala Seq + args(i) = args(i).asInstanceOf[Array[_]].toSeq + } + } + + return Some(index) + } + } + } + None + } + + def logError(id: String, e: Exception): Unit = {} +} + + diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala new file mode 100644 index 000000000..c70d16b03 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetException.scala @@ -0,0 +1,13 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +class DotnetException(message: String, cause: Throwable) + extends Exception(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala new file mode 100644 index 000000000..f5277c215 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetRDD.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.SparkContext +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python._ +import org.apache.spark.rdd.RDD + +object DotnetRDD { + def createPythonRDD( + parent: RDD[_], + func: PythonFunction, + preservePartitoning: Boolean): PythonRDD = { + new PythonRDD(parent, func, preservePartitoning) + } + + def createJavaRDDFromArray( + sc: SparkContext, + arr: Array[Array[Byte]], + numSlices: Int): JavaRDD[Array[Byte]] = { + JavaRDD.fromRDD(sc.parallelize(arr, numSlices)) + } + + def toJavaRDD(rdd: RDD[_]): JavaRDD[_] = JavaRDD.fromRDD(rdd) +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala new file mode 100644 index 000000000..9f556338b --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/DotnetUtils.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import scala.collection.JavaConverters._ + +/** DotnetUtils object that hosts some helper functions + * help data type conversions between dotnet and scala + */ +object DotnetUtils { + + /** A helper function to convert scala Map to java.util.Map + * @param value - scala Map + * @return java.util.Map + */ + def convertToJavaMap(value: Map[_, _]): java.util.Map[_, _] = value.asJava + + /** Convert java data type to corresponding scala type + * @param value - java.lang.Object + * @return Any + */ + def mapScalaToJava(value: java.lang.Object): Any = { + value match { + case i: java.lang.Integer => i.toInt + case d: java.lang.Double => d.toDouble + case f: java.lang.Float => f.toFloat + case b: java.lang.Boolean => b.booleanValue() + case l: java.lang.Long => l.toLong + case s: java.lang.Short => s.toShort + case by: java.lang.Byte => by.toByte + case c: java.lang.Character => c.toChar + case _ => value + } + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala new file mode 100644 index 000000000..81cfaf88b --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JVMObjectTracker.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import scala.collection.mutable.HashMap + +/** + * Tracks JVM objects returned to .NET which is useful for invoking calls from .NET on JVM objects. + */ +private[dotnet] class JVMObjectTracker { + + // Multiple threads may access objMap and increase objCounter. Because get method return Option, + // it is convenient to use a Scala map instead of java.util.concurrent.ConcurrentHashMap. + private[this] val objMap = new HashMap[String, Object] + private[this] var objCounter: Int = 1 + + def getObject(id: String): Object = { + synchronized { + objMap(id) + } + } + + def get(id: String): Option[Object] = { + synchronized { + objMap.get(id) + } + } + + def put(obj: Object): String = { + synchronized { + val objId = objCounter.toString + objCounter = objCounter + 1 + objMap.put(objId, obj) + objId + } + } + + def remove(id: String): Option[Object] = { + synchronized { + objMap.remove(id) + } + } + + def clear(): Unit = { + synchronized { + objMap.clear() + objCounter = 1 + } + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala new file mode 100644 index 000000000..06a476f67 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/JvmBridgeUtils.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.SparkConf + +/* + * Utils for JvmBridge. 
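+ *
+ * For example, getKeyValuePairArrayAsString(Array(("spark.master", "local"), ("spark.app.name", "app")))
+ * returns "spark.master=local;spark.app.name=app;" (note the trailing separator).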
+ */ +object JvmBridgeUtils { + def getKeyValuePairAsString(kvp: (String, String)): String = { + return kvp._1 + "=" + kvp._2 + } + + def getKeyValuePairArrayAsString(kvpArray: Array[(String, String)]): String = { + val sb = new StringBuilder + + for (kvp <- kvpArray) { + sb.append(getKeyValuePairAsString(kvp)) + sb.append(";") + } + + sb.toString + } + + def getSparkConfAsString(sparkConf: SparkConf): String = { + getKeyValuePairArrayAsString(sparkConf.getAll) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala new file mode 100644 index 000000000..a3df3788a --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/SerDe.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.io.{DataInputStream, DataOutputStream} +import java.nio.charset.StandardCharsets +import java.sql.{Date, Time, Timestamp} + +import org.apache.spark.sql.Row + +import scala.collection.JavaConverters._ + +/** + * Class responsible for serialization and deserialization between CLR & JVM. + * This implementation of methods is mostly identical to the SerDe implementation in R. + */ +class SerDe(val tracker: JVMObjectTracker) { + + def readObjectType(dis: DataInputStream): Char = { + dis.readByte().toChar + } + + def readObject(dis: DataInputStream): Object = { + val dataType = readObjectType(dis) + readTypedObject(dis, dataType) + } + + private def readTypedObject(dis: DataInputStream, dataType: Char): Object = { + dataType match { + case 'n' => null + case 'i' => new java.lang.Integer(readInt(dis)) + case 'g' => new java.lang.Long(readLong(dis)) + case 'd' => new java.lang.Double(readDouble(dis)) + case 'b' => new java.lang.Boolean(readBoolean(dis)) + case 'c' => readString(dis) + case 'e' => readMap(dis) + case 'r' => readBytes(dis) + case 'l' => readList(dis) + case 'D' => readDate(dis) + case 't' => readTime(dis) + case 'j' => tracker.getObject(readString(dis)) + case 'R' => readRowArr(dis) + case 'O' => readObjectArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + private def readBytes(in: DataInputStream): Array[Byte] = { + val len = readInt(in) + val out = new Array[Byte](len) + in.readFully(out) + out + } + + def readInt(in: DataInputStream): Int = { + in.readInt() + } + + private def readLong(in: DataInputStream): Long = { + in.readLong() + } + + private def readDouble(in: DataInputStream): Double = { + in.readDouble() + } + + private def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + val str = new String(bytes, "UTF-8") + str + } + + def readString(in: DataInputStream): String = { + val len = in.readInt() + readStringBytes(in, len) + } + + def readBoolean(in: DataInputStream): Boolean = { + in.readBoolean() + } + + private def readDate(in: DataInputStream): Date = { + Date.valueOf(readString(in)) + } + + private def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + + private def readRow(in: DataInputStream): Row = { + val len = readInt(in) + 
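+ // Each of the len columns is decoded with readObject, so primitive values and tracked 'jobj' references are handled uniformly.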
Row.fromSeq((0 until len).map(_ => readObject(in))) + } + + private def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { + val len = readInt(in) + (0 until len).map(_ => readBytes(in)).toArray + } + + private def readIntArr(in: DataInputStream): Array[Int] = { + val len = readInt(in) + (0 until len).map(_ => readInt(in)).toArray + } + + private def readLongArr(in: DataInputStream): Array[Long] = { + val len = readInt(in) + (0 until len).map(_ => readLong(in)).toArray + } + + private def readDoubleArr(in: DataInputStream): Array[Double] = { + val len = readInt(in) + (0 until len).map(_ => readDouble(in)).toArray + } + + private def readDoubleArrArr(in: DataInputStream): Array[Array[Double]] = { + val len = readInt(in) + (0 until len).map(_ => readDoubleArr(in)).toArray + } + + private def readBooleanArr(in: DataInputStream): Array[Boolean] = { + val len = readInt(in) + (0 until len).map(_ => readBoolean(in)).toArray + } + + private def readStringArr(in: DataInputStream): Array[String] = { + val len = readInt(in) + (0 until len).map(_ => readString(in)).toArray + } + + private def readRowArr(in: DataInputStream): java.util.List[Row] = { + val len = readInt(in) + (0 until len).map(_ => readRow(in)).toList.asJava + } + + private def readObjectArr(in: DataInputStream): Seq[Any] = { + val len = readInt(in) + (0 until len).map(_ => readObject(in)) + } + + private def readList(dis: DataInputStream): Array[_] = { + val arrType = readObjectType(dis) + arrType match { + case 'i' => readIntArr(dis) + case 'g' => readLongArr(dis) + case 'c' => readStringArr(dis) + case 'd' => readDoubleArr(dis) + case 'A' => readDoubleArrArr(dis) + case 'b' => readBooleanArr(dis) + case 'j' => readStringArr(dis).map(x => tracker.getObject(x)) + case 'r' => readBytesArr(dis) + case _ => throw new IllegalArgumentException(s"Invalid array type $arrType") + } + } + + private def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + val len = readInt(in) + if (len > 0) { + val keysType = readObjectType(in) + val keysLen = readInt(in) + val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) + + val valuesLen = readInt(in) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) + keys.zip(values).toMap.asJava + } else { + new java.util.HashMap[Object, Object]() + } + } + + // Using the same mapping as SparkR implementation for now + // Methods to write out data from Java to .NET. 
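+ // For example, writeObject(dos, Integer.valueOf(42)) writes the type tag 'i' followed by the 4-byte big-endian value, the same tag-then-payload layout that the read methods above consume.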
+ // + // Type mapping from Java to .NET: + // + // void -> NULL + // Int -> integer + // String -> character + // Boolean -> logical + // Float -> double + // Double -> double + // Long -> long + // Array[Byte] -> raw + // Date -> Date + // Time -> POSIXct + // + // Array[T] -> list() + // Object -> jobj + + def writeType(dos: DataOutputStream, typeStr: String): Unit = { + typeStr match { + case "void" => dos.writeByte('n') + case "character" => dos.writeByte('c') + case "double" => dos.writeByte('d') + case "doublearray" => dos.writeByte('A') + case "long" => dos.writeByte('g') + case "integer" => dos.writeByte('i') + case "logical" => dos.writeByte('b') + case "date" => dos.writeByte('D') + case "time" => dos.writeByte('t') + case "raw" => dos.writeByte('r') + case "list" => dos.writeByte('l') + case "jobj" => dos.writeByte('j') + case _ => throw new IllegalArgumentException(s"Invalid type $typeStr") + } + } + + def writeObject(dos: DataOutputStream, value: Object): Unit = { + if (value == null || value == Unit) { + writeType(dos, "void") + } else { + value.getClass.getName match { + case "java.lang.String" => + writeType(dos, "character") + writeString(dos, value.asInstanceOf[String]) + case "float" | "java.lang.Float" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Float].toDouble) + case "double" | "java.lang.Double" => + writeType(dos, "double") + writeDouble(dos, value.asInstanceOf[Double]) + case "long" | "java.lang.Long" => + writeType(dos, "long") + writeLong(dos, value.asInstanceOf[Long]) + case "int" | "java.lang.Integer" => + writeType(dos, "integer") + writeInt(dos, value.asInstanceOf[Int]) + case "boolean" | "java.lang.Boolean" => + writeType(dos, "logical") + writeBoolean(dos, value.asInstanceOf[Boolean]) + case "java.sql.Date" => + writeType(dos, "date") + writeDate(dos, value.asInstanceOf[Date]) + case "java.sql.Time" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, "time") + writeTime(dos, value.asInstanceOf[Timestamp]) + case "[B" => + writeType(dos, "raw") + writeBytes(dos, value.asInstanceOf[Array[Byte]]) + // TODO: Types not handled right now include + // byte, char, short, float + + // Handle arrays + case "[Ljava.lang.String;" => + writeType(dos, "list") + writeStringArr(dos, value.asInstanceOf[Array[String]]) + case "[I" => + writeType(dos, "list") + writeIntArr(dos, value.asInstanceOf[Array[Int]]) + case "[J" => + writeType(dos, "list") + writeLongArr(dos, value.asInstanceOf[Array[Long]]) + case "[D" => + writeType(dos, "list") + writeDoubleArr(dos, value.asInstanceOf[Array[Double]]) + case "[[D" => + writeType(dos, "list") + writeDoubleArrArr(dos, value.asInstanceOf[Array[Array[Double]]]) + case "[Z" => + writeType(dos, "list") + writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]]) + case "[[B" => + writeType(dos, "list") + writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]]) + case otherName => + // Handle array of objects + if (otherName.startsWith("[L")) { + val objArr = value.asInstanceOf[Array[Object]] + writeType(dos, "list") + writeType(dos, "jobj") + dos.writeInt(objArr.length) + objArr.foreach(o => writeJObj(dos, o)) + } else { + writeType(dos, "jobj") + writeJObj(dos, value) + } + } + } + } + + def writeInt(out: DataOutputStream, value: Int): Unit = { + out.writeInt(value) + } + + def writeLong(out: DataOutputStream, value: Long): Unit = { + out.writeLong(value) + } + + private def writeDouble(out: DataOutputStream, value: Double): Unit = { + 
out.writeDouble(value) + } + + private def writeBoolean(out: DataOutputStream, value: Boolean): Unit = { + out.writeBoolean(value) + } + + private def writeDate(out: DataOutputStream, value: Date): Unit = { + writeString(out, value.toString) + } + + private def writeTime(out: DataOutputStream, value: Time): Unit = { + out.writeDouble(value.getTime.toDouble / 1000.0) + } + + private def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } + + def writeString(out: DataOutputStream, value: String): Unit = { + val utf8 = value.getBytes(StandardCharsets.UTF_8) + val len = utf8.length + out.writeInt(len) + out.write(utf8, 0, len) + } + + private def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = { + out.writeInt(value.length) + out.write(value) + } + + def writeJObj(out: DataOutputStream, value: Object): Unit = { + val objId = tracker.put(value) + writeString(out, objId) + } + + private def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { + writeType(out, "integer") + out.writeInt(value.length) + value.foreach(v => out.writeInt(v)) + } + + private def writeLongArr(out: DataOutputStream, value: Array[Long]): Unit = { + writeType(out, "long") + out.writeInt(value.length) + value.foreach(v => out.writeLong(v)) + } + + private def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = { + writeType(out, "double") + out.writeInt(value.length) + value.foreach(v => out.writeDouble(v)) + } + + private def writeDoubleArrArr(out: DataOutputStream, value: Array[Array[Double]]): Unit = { + writeType(out, "doublearray") + out.writeInt(value.length) + value.foreach(v => writeDoubleArr(out, v)) + } + + private def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = { + writeType(out, "logical") + out.writeInt(value.length) + value.foreach(v => writeBoolean(out, v)) + } + + private def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = { + writeType(out, "character") + out.writeInt(value.length) + value.foreach(v => writeString(out, v)) + } + + private def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = { + writeType(out, "raw") + out.writeInt(value.length) + value.foreach(v => writeBytes(out, v)) + } +} + +private object SerializationFormats { + val BYTE = "byte" + val STRING = "string" + val ROW = "row" +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala new file mode 100644 index 000000000..50551a7d9 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/api/dotnet/ThreadPool.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import java.util.concurrent.{ExecutorService, Executors} + +import scala.collection.mutable + +/** + * Pool of thread executors. There should be a 1-1 correspondence between C# threads + * and Java threads. + */ +object ThreadPool { + + /** + * Map from (processId, threadId) to corresponding executor. + */ + private val executors: mutable.HashMap[(Int, Int), ExecutorService] = + new mutable.HashMap[(Int, Int), ExecutorService]() + + /** + * Run some code on a particular thread. 
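+ * For example (caller-side sketch), ThreadPool.run(1, 1, () => doWork()) executes doWork on the single-thread executor keyed by (1, 1) and blocks until it finishes, because the submitted future is awaited.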
+ * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @param task Function to run on the thread. + */ + def run(processId: Int, threadId: Int, task: () => Unit): Unit = { + val executor = getOrCreateExecutor(processId, threadId) + val future = executor.submit(new Runnable { + override def run(): Unit = task() + }) + + future.get() + } + + /** + * Try to delete a particular thread. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return True if successful, false if thread does not exist. + */ + def tryDeleteThread(processId: Int, threadId: Int): Boolean = synchronized { + executors.remove((processId, threadId)) match { + case Some(executorService) => + executorService.shutdown() + true + case None => false + } + } + + /** + * Shutdown any running ExecutorServices. + */ + def shutdown(): Unit = synchronized { + executors.foreach(_._2.shutdown()) + executors.clear() + } + + /** + * Get the executor if it exists, otherwise create a new one. + * @param processId Integer id of the process. + * @param threadId Integer id of the thread. + * @return The new or existing executor with the given id. + */ + private def getOrCreateExecutor(processId: Int, threadId: Int): ExecutorService = synchronized { + executors.getOrElseUpdate((processId, threadId), Executors.newSingleThreadExecutor) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala new file mode 100644 index 000000000..4551a70bd --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotNetUserAppException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.deploy.dotnet + +import org.apache.spark.SparkException + +/** + * This exception type describes an exception thrown by a .NET user application. + * + * @param exitCode Exit code returned by the .NET application. + * @param dotNetStackTrace Stacktrace extracted from .NET application logs. + */ +private[spark] class DotNetUserAppException(exitCode: Int, dotNetStackTrace: Option[String]) + extends SparkException( + dotNetStackTrace match { + case None => s"User application exited with $exitCode" + case Some(e) => s"User application exited with $exitCode and .NET exception: $e" + }) diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala new file mode 100644 index 000000000..2d1c2cf3b --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/deploy/dotnet/DotnetRunner.scala @@ -0,0 +1,309 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.deploy.dotnet + +import java.io.File +import java.net.URI +import java.nio.file.attribute.PosixFilePermissions +import java.nio.file.{FileSystems, Files, Paths} +import java.util.Locale +import java.util.concurrent.{Semaphore, TimeUnit} + +import org.apache.commons.io.FilenameUtils +import org.apache.commons.io.output.TeeOutputStream +import org.apache.hadoop.fs.Path +import org.apache.spark +import org.apache.spark.api.dotnet.DotnetBackend +import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.{ + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK, + ERROR_BUFFER_SIZE, ERROR_REDIRECITON_ENABLED +} +import org.apache.spark.util.dotnet.{Utils => DotnetUtils} +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.{SecurityManager, SparkConf, SparkUserAppException} + +import scala.collection.JavaConverters._ +import scala.io.StdIn +import scala.util.Try + +/** + * DotnetRunner class used to launch Spark .NET applications using spark-submit. + * It executes .NET application as a subprocess and then has it connect back to + * the JVM to access system properties etc. + */ +object DotnetRunner extends Logging { + private val DEBUG_PORT = 5567 + private val supportedSparkMajorMinorVersionPrefix = "3.4" + private val supportedSparkVersions = Set[String]("3.2.0", "3.2.1", "3.2.2", "3.2.3", "3.3.0", "3.3.1", "3.3.2", "3.4.0") + + val SPARK_VERSION = DotnetUtils.normalizeSparkVersion(spark.SPARK_VERSION) + + def main(args: Array[String]): Unit = { + if (args.length == 0) { + throw new IllegalArgumentException("At least one argument is expected.") + } + + DotnetUtils.validateSparkVersions( + sys.props + .getOrElse( + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.key, + DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK.defaultValue.get.toString) + .toBoolean, + spark.SPARK_VERSION, + SPARK_VERSION, + supportedSparkMajorMinorVersionPrefix, + supportedSparkVersions) + + val settings = initializeSettings(args) + + // Determines if this needs to be run in debug mode. + // In debug mode this runner will not launch a .NET process. + val runInDebugMode = settings._1 + @volatile var dotnetBackendPortNumber = settings._2 + var dotnetExecutable = "" + var otherArgs: Array[String] = null + + if (!runInDebugMode) { + if (args(0).toLowerCase(Locale.ROOT).endsWith(".zip")) { + var zipFileName = args(0) + val zipFileUri = Try(new URI(zipFileName)).getOrElse(new File(zipFileName).toURI) + val workingDir = new File("").getAbsoluteFile + val driverDir = new File(workingDir, FilenameUtils.getBaseName(zipFileUri.getPath())) + + // Standalone cluster mode where .NET application is remotely located. + if (zipFileUri.getScheme() != "file") { + zipFileName = downloadDriverFile(zipFileName, workingDir.getAbsolutePath).getName + } + + logInfo(s"Unzipping .NET driver $zipFileName to $driverDir") + DotnetUtils.unzip(new File(zipFileName), driverDir) + + // Reuse windows-specific formatting in PythonRunner. + dotnetExecutable = PythonRunner.formatPath(resolveDotnetExecutable(driverDir, args(1))) + otherArgs = args.slice(2, args.length) + } else { + // Reuse windows-specific formatting in PythonRunner. 
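+ // In this branch args(0) is the path to the .NET executable itself, since no zip archive was supplied.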
+ dotnetExecutable = PythonRunner.formatPath(args(0)) + otherArgs = args.slice(1, args.length) + } + } else { + otherArgs = args.slice(1, args.length) + } + + val processParameters = new java.util.ArrayList[String] + processParameters.add(dotnetExecutable) + otherArgs.foreach(arg => processParameters.add(arg)) + + logInfo(s"Starting DotnetBackend with $dotnetExecutable.") + + // Time to wait for DotnetBackend to initialize in seconds. + val backendTimeout = sys.env.getOrElse("DOTNETBACKEND_TIMEOUT", "120").toInt + + // Launch a DotnetBackend server for the .NET process to connect to; this will let it see our + // Java system properties etc. + val dotnetBackend = new DotnetBackend() + val initialized = new Semaphore(0) + val dotnetBackendThread = new Thread("DotnetBackend") { + override def run() { + // need to get back dotnetBackendPortNumber because if the value passed to init is 0 + // the port number is dynamically assigned in the backend + dotnetBackendPortNumber = dotnetBackend.init(dotnetBackendPortNumber) + logInfo(s"Port number used by DotnetBackend is $dotnetBackendPortNumber") + initialized.release() + dotnetBackend.run() + } + } + + dotnetBackendThread.start() + + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + if (!runInDebugMode) { + var returnCode = -1 + var process: Process = null + val enableLogRedirection: Boolean = sys.props + .getOrElse( + ERROR_REDIRECITON_ENABLED.key, + ERROR_REDIRECITON_ENABLED.defaultValue.get.toString).toBoolean + val stderrBuffer: Option[CircularBuffer] = Option(enableLogRedirection).collect { + case true => new CircularBuffer( + sys.props.getOrElse( + ERROR_BUFFER_SIZE.key, + ERROR_BUFFER_SIZE.defaultValue.get.toString).toInt) + } + + try { + val builder = new ProcessBuilder(processParameters) + val env = builder.environment() + env.put("DOTNETBACKEND_PORT", dotnetBackendPortNumber.toString) + + for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) { + env.put(key, value) + logInfo(s"Adding key=$key and value=$value to environment") + } + builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize + process = builder.start() + + // Redirect stdin of JVM process to stdin of .NET process. + new RedirectThread(System.in, process.getOutputStream, "redirect JVM input").start() + // Redirect stdout and stderr of .NET process to System.out and to buffer + // if log direction is enabled. If not, redirect only to System.out. + new RedirectThread( + process.getInputStream, + stderrBuffer match { + case Some(buffer) => new TeeOutputStream(System.out, buffer) + case _ => System.out + }, + "redirect .NET stdout and stderr").start() + + process.waitFor() + } catch { + case t: Throwable => + logThrowable(t) + } finally { + returnCode = closeDotnetProcess(process) + closeBackend(dotnetBackend) + } + if (returnCode != 0) { + if (stderrBuffer.isDefined) { + throw new DotNetUserAppException(returnCode, Some(stderrBuffer.get.toString)) + } else { + throw new SparkUserAppException(returnCode) + } + } else { + logInfo(s".NET application exited successfully") + } + // TODO: The following is causing the following error: + // INFO ApplicationMaster: Final app status: FAILED, exitCode: 16, + // (reason: Shutdown hook called before final status was reported.) + // DotnetUtils.exit(returnCode) + } else { + // scalastyle:off println + println("***********************************************************************") + println("* .NET Backend running debug mode. 
Press enter to exit *") + println("***********************************************************************") + // scalastyle:on println + + StdIn.readLine() + closeBackend(dotnetBackend) + DotnetUtils.exit(0) + } + } else { + logError(s"DotnetBackend did not initialize in $backendTimeout seconds") + DotnetUtils.exit(-1) + } + } + + // When the executable is downloaded as part of zip file, check if the file exists + // after zip file is unzipped under the given dir. Once it is found, change the + // permission to executable (only for Unix systems, since the zip file may have been + // created under Windows. Finally, the absolute path for the executable is returned. + private def resolveDotnetExecutable(dir: File, dotnetExecutable: String): String = { + val path = Paths.get(dir.getAbsolutePath, dotnetExecutable) + val resolvedExecutable = if (Files.isRegularFile(path)) { + path.toAbsolutePath.toString + } else { + Files + .walk(FileSystems.getDefault.getPath(dir.getAbsolutePath)) + .iterator() + .asScala + .find(path => Files.isRegularFile(path) && path.getFileName.toString == dotnetExecutable) match { + case Some(path) => path.toAbsolutePath.toString + case None => + throw new IllegalArgumentException( + s"Failed to find $dotnetExecutable under ${dir.getAbsolutePath}") + } + } + + if (DotnetUtils.supportPosix) { + Files.setPosixFilePermissions( + Paths.get(resolvedExecutable), + PosixFilePermissions.fromString("rwxr-xr-x")) + } + + resolvedExecutable + } + + /** + * Download HDFS file into the supplied directory and return its local path. + * Will throw an exception if there are errors during downloading. + */ + private def downloadDriverFile(hdfsFilePath: String, driverDir: String): File = { + val sparkConf = new SparkConf() + val filePath = new Path(hdfsFilePath) + + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val jarFileName = filePath.getName + val localFile = new File(driverDir, jarFileName) + + if (!localFile.exists()) { // May already exist if running multiple workers on one node + logInfo(s"Copying user file $filePath to $driverDir") + Utils.fetchFile( + hdfsFilePath, + new File(driverDir), + sparkConf, + hadoopConf, + System.currentTimeMillis(), + useCache = false) + } + + if (!localFile.exists()) { + throw new Exception(s"Did not see expected $jarFileName in $driverDir") + } + + localFile + } + + private def closeBackend(dotnetBackend: DotnetBackend): Unit = { + logInfo("Closing DotnetBackend") + dotnetBackend.close() + } + + private def closeDotnetProcess(dotnetProcess: Process): Int = { + if (dotnetProcess == null) { + return -1 + } else if (!dotnetProcess.isAlive) { + return dotnetProcess.exitValue() + } + + // Try to (gracefully on Linux) kill the process and resort to force if interrupted + var returnCode = -1 + logInfo("Closing .NET process") + try { + dotnetProcess.destroy() + returnCode = dotnetProcess.waitFor() + } catch { + case _: InterruptedException => + logInfo( + "Thread interrupted while waiting for graceful close. 
Forcefully closing .NET process") + returnCode = dotnetProcess.destroyForcibly().waitFor() + case t: Throwable => + logThrowable(t) + } + + returnCode + } + + private def initializeSettings(args: Array[String]): (Boolean, Int) = { + val runInDebugMode = (args.length == 1 || args.length == 2) && args(0).equalsIgnoreCase( + "debug") + var portNumber = 0 + if (runInDebugMode) { + if (args.length == 1) { + portNumber = DEBUG_PORT + } else if (args.length == 2) { + portNumber = Integer.parseInt(args(1)) + } + } + + (runInDebugMode, portNumber) + } + + private def logThrowable(throwable: Throwable): Unit = + logError(s"${throwable.getMessage} \n ${throwable.getStackTrace.mkString("\n")}") +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala new file mode 100644 index 000000000..18ba4c6e5 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/internal/config/dotnet/Dotnet.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.internal.config.dotnet + +import org.apache.spark.internal.config.ConfigBuilder + +private[spark] object Dotnet { + val DOTNET_NUM_BACKEND_THREADS = ConfigBuilder("spark.dotnet.numDotnetBackendThreads").intConf + .createWithDefault(10) + + val DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK = + ConfigBuilder("spark.dotnet.ignoreSparkPatchVersionCheck").booleanConf + .createWithDefault(false) + + val ERROR_REDIRECITON_ENABLED = + ConfigBuilder("spark.nonjvm.error.forwarding.enabled").booleanConf + .createWithDefault(false) + + val ERROR_BUFFER_SIZE = + ConfigBuilder("spark.nonjvm.error.buffer.size") + .intConf + .checkValue(_ >= 0, "The error buffer size must not be negative") + .createWithDefault(10240) +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala new file mode 100644 index 000000000..3e3c3e0e3 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/mllib/api/dotnet/MLUtils.scala @@ -0,0 +1,26 @@ + +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + + package org.apache.spark.mllib.api.dotnet + + import org.apache.spark.ml._ + import scala.collection.JavaConverters._ + + /** MLUtils hosts helper functions + * related to ML usage. + */ + object MLUtils { + + /** A helper function that lets a Pipeline accept its stages as a java.util.ArrayList + * from Scala code. + * @param pipeline - The pipeline whose stages are to be set + * @param value - A java.util.ArrayList of PipelineStages to set as the stages + * @return The pipeline + */ + def setPipelineStages(pipeline: Pipeline, value: java.util.ArrayList[_ <: PipelineStage]): Pipeline = + pipeline.setStages(value.asScala.toArray) + } diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala new file mode 100644 index 000000000..5d06d4304 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/DotnetForeachBatch.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import org.apache.spark.api.dotnet.CallbackClient +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.streaming.DataStreamWriter + +class DotnetForeachBatchFunction(callbackClient: CallbackClient, callbackId: Int) extends Logging { + def call(batchDF: DataFrame, batchId: Long): Unit = + callbackClient.send( + callbackId, + (dos, serDe) => { + serDe.writeJObj(dos, batchDF) + serDe.writeLong(dos, batchId) + }) +} + +object DotnetForeachBatchHelper { + def callForeachBatch(client: Option[CallbackClient], dsw: DataStreamWriter[Row], callbackId: Int): Unit = { + val dotnetForeachFunc = client match { + case Some(value) => new DotnetForeachBatchFunction(value, callbackId) + case None => throw new Exception("CallbackClient is null.") + } + + dsw.foreachBatch(dotnetForeachFunc.call _) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala new file mode 100644 index 000000000..b5e97289a --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/api/dotnet/SQLUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.api.dotnet + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.api.python.{PythonAccumulatorV2, PythonBroadcast, PythonFunction} +import org.apache.spark.broadcast.Broadcast + +object SQLUtils { + + /** + * Exposes createPythonFunction to the .NET client to enable registering UDFs.
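+ * The arguments are forwarded unchanged to PythonFunction, so the .NET caller supplies the serialized UDF command bytes along with the environment variables, includes, Python executable and version, broadcast variables and accumulator that Spark's Python UDF machinery expects.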
+ */ + def createPythonFunction( + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVersion: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: PythonAccumulatorV2): PythonFunction = { + + PythonFunction( + command, + envVars, + pythonIncludes, + pythonExec, + pythonVersion, + broadcastVars, + accumulator) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala new file mode 100644 index 000000000..1cd45aa95 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/sql/test/TestUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.sql.test + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.MemoryStream + +object TestUtils { + + /** + * Helper method to create typed MemoryStreams intended for use in unit tests. + * @param sqlContext The SQLContext. + * @param streamType The type of memory stream to create. This string is the `Name` + * property of the dotnet type. + * @return A typed MemoryStream. + */ + def createMemoryStream(implicit sqlContext: SQLContext, streamType: String): MemoryStream[_] = { + import sqlContext.implicits._ + + streamType match { + case "Int32" => MemoryStream[Int] + case "String" => MemoryStream[String] + case _ => throw new Exception(s"$streamType not supported") + } + } +} diff --git a/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala new file mode 100644 index 000000000..f9400789a --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/main/scala/org/apache/spark/util/dotnet/Utils.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.util.dotnet + +import java.io._ +import java.nio.file.attribute.PosixFilePermission +import java.nio.file.attribute.PosixFilePermission._ +import java.nio.file.{FileSystems, Files} +import java.util.{Timer, TimerTask} + +import org.apache.commons.compress.archivers.zip.{ZipArchiveEntry, ZipArchiveOutputStream, ZipFile} +import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK + +import scala.collection.JavaConverters._ +import scala.collection.Set + +/** + * Utility methods. 
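+ * For example (paths illustrative), zip(new File("publish"), new File("app.zip")) archives a directory while recording POSIX permissions where the file system supports them, and unzip(new File("app.zip"), new File("out")) restores the entries and re-applies those permissions.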
+ */ +object Utils extends Logging { + private val posixFilePermissions = Array( + OWNER_READ, + OWNER_WRITE, + OWNER_EXECUTE, + GROUP_READ, + GROUP_WRITE, + GROUP_EXECUTE, + OTHERS_READ, + OTHERS_WRITE, + OTHERS_EXECUTE) + + val supportPosix: Boolean = + FileSystems.getDefault.supportedFileAttributeViews().contains("posix") + + /** + * Compress all files under given directory into one zip file and drop it to the target directory + * + * @param sourceDir source directory to zip + * @param targetZipFile target zip file + */ + def zip(sourceDir: File, targetZipFile: File): Unit = { + var fos: FileOutputStream = null + var zos: ZipArchiveOutputStream = null + try { + fos = new FileOutputStream(targetZipFile) + zos = new ZipArchiveOutputStream(fos) + + val sourcePath = sourceDir.toPath + FileUtils.listFiles(sourceDir, null, true).asScala.foreach { file => + var in: FileInputStream = null + try { + val path = file.toPath + val entry = new ZipArchiveEntry(sourcePath.relativize(path).toString) + if (supportPosix) { + entry.setUnixMode( + permissionsToMode(Files.getPosixFilePermissions(path).asScala) + | (if (entry.getName.endsWith(".exe")) 0x1ED else 0x1A4)) + } else if (entry.getName.endsWith(".exe")) { + entry.setUnixMode(0x1ED) // 755 + } else { + entry.setUnixMode(0x1A4) // 644 + } + zos.putArchiveEntry(entry) + + in = new FileInputStream(file) + IOUtils.copy(in, zos) + zos.closeArchiveEntry() + } finally { + IOUtils.closeQuietly(in) + } + } + } finally { + IOUtils.closeQuietly(zos) + IOUtils.closeQuietly(fos) + } + } + + /** + * Unzip a file to the given directory + * + * @param file file to be unzipped + * @param targetDir target directory + */ + def unzip(file: File, targetDir: File): Unit = { + var zipFile: ZipFile = null + try { + targetDir.mkdirs() + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.foreach { entry => + val targetFile = new File(targetDir, entry.getName) + + if (targetFile.exists()) { + logWarning( + s"Target file/directory $targetFile already exists. Skip it for now. " + + s"Make sure this is expected.") + } else { + if (entry.isDirectory) { + targetFile.mkdirs() + } else { + targetFile.getParentFile.mkdirs() + val input = zipFile.getInputStream(entry) + val output = new FileOutputStream(targetFile) + IOUtils.copy(input, output) + IOUtils.closeQuietly(input) + IOUtils.closeQuietly(output) + if (supportPosix) { + val permissions = modeToPermissions(entry.getUnixMode) + // When run in Unix system, permissions will be empty, thus skip + // setting the empty permissions (which will empty the previous permissions). + if (permissions.nonEmpty) { + Files.setPosixFilePermissions(targetFile.toPath, permissions.asJava) + } + } + } + } + } + } catch { + case e: Exception => logError("exception caught during decompression:" + e) + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + /** + * Exits the JVM, trying to do it nicely, otherwise doing it nastily. 
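+ * A watchdog Timer forces Runtime.getRuntime.halt(status) after maxDelayMillis in case the normal System.exit(status) path is blocked, for example by a misbehaving shutdown hook.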
+ * + * @param status the exit status, zero for OK, non-zero for error + * @param maxDelayMillis the maximum delay in milliseconds + */ + def exit(status: Int, maxDelayMillis: Long) { + try { + logInfo(s"Utils.exit() with status: $status, maxDelayMillis: $maxDelayMillis") + + // setup a timer, so if nice exit fails, the nasty exit happens + val timer = new Timer() + timer.schedule(new TimerTask() { + + override def run() { + Runtime.getRuntime.halt(status) + } + }, maxDelayMillis) + // try to exit nicely + System.exit(status); + } catch { + // exit nastily if we have a problem + case _: Throwable => Runtime.getRuntime.halt(status) + } finally { + // should never get here + Runtime.getRuntime.halt(status) + } + } + + /** + * Exits the JVM, trying to do it nicely, wait 1 second + * + * @param status the exit status, zero for OK, non-zero for error + */ + def exit(status: Int): Unit = { + exit(status, 1000) + } + + /** + * Normalize the Spark version by taking the first three numbers. + * For example: + * x.y.z => x.y.z + * x.y.z.xxx.yyy => x.y.z + * x.y => x.y + * + * @param version the Spark version to normalize + * @return Normalized Spark version. + */ + def normalizeSparkVersion(version: String): String = { + version + .split('.') + .take(3) + .zipWithIndex + .map({ + case (element, index) => { + index match { + case 2 => element.split("\\D+").lift(0).getOrElse("") + case _ => element + } + } + }) + .mkString(".") + } + + /** + * Validates the normalized spark version by verifying: + * - Spark version starts with sparkMajorMinorVersionPrefix. + * - If ignoreSparkPatchVersion is + * - true: valid + * - false: check if the spark version is in supportedSparkVersions. + * @param ignoreSparkPatchVersion Ignore spark patch version. + * @param sparkVersion The spark version. + * @param normalizedSparkVersion: The normalized spark version. + * @param supportedSparkMajorMinorVersionPrefix The spark major and minor version to validate against. + * @param supportedSparkVersions The set of supported spark versions. + */ + def validateSparkVersions( + ignoreSparkPatchVersion: Boolean, + sparkVersion: String, + normalizedSparkVersion: String, + supportedSparkMajorMinorVersionPrefix: String, + supportedSparkVersions: Set[String]): Unit = { + if (!normalizedSparkVersion.startsWith(s"$supportedSparkMajorMinorVersionPrefix.")) { + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported spark major.minor version: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (ignoreSparkPatchVersion) { + logWarning( + s"Ignoring spark patch version. Spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Spark major.minor prefix used: '$supportedSparkMajorMinorVersionPrefix'.") + } else if (!supportedSparkVersions(normalizedSparkVersion)) { + val supportedVersions = supportedSparkVersions.toSeq.sorted.mkString(", ") + throw new IllegalArgumentException( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. 
" + + s"Supported versions: '$supportedVersions'.") + } + } + + private[spark] def listZipFileEntries(file: File): Array[String] = { + var zipFile: ZipFile = null + try { + zipFile = new ZipFile(file) + zipFile.getEntries.asScala.map(_.getName).toArray + } finally { + ZipFile.closeQuietly(zipFile) + } + } + + private[this] def permissionsToMode(permissions: Set[PosixFilePermission]): Int = { + posixFilePermissions.foldLeft(0) { (mode, perm) => + (mode << 1) | (if (permissions.contains(perm)) 1 else 0) + } + } + + private[this] def modeToPermissions(mode: Int): Set[PosixFilePermission] = { + posixFilePermissions.zipWithIndex + .filter { case (_, i) => (mode & (0x100 >>> i)) != 0 } + .map(_._1) + .toSet + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala new file mode 100644 index 000000000..7088537e1 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendHandlerTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import Extensions._ +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +@Test +class DotnetBackendHandlerTest { + private var backend: DotnetBackend = _ + private var tracker: JVMObjectTracker = _ + private var handler: DotnetBackendHandler = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + tracker = new JVMObjectTracker + handler = new DotnetBackendHandler(backend, tracker) + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldTrackCallbackClientWhenDotnetProcessConnected(): Unit = { + val message = givenMessage(m => { + val serDe = new SerDe(null) + m.writeBoolean(true) // static method + serDe.writeInt(m, 1) // processId + serDe.writeInt(m, 1) // threadId + serDe.writeString(m, "DotnetHandler") // class name + serDe.writeString(m, "connectCallback") // command (method) name + m.writeInt(2) // number of arguments + m.writeByte('c') // 1st argument type (string) + serDe.writeString(m, "127.0.0.1") // 1st argument value (host) + m.writeByte('i') // 2nd argument type (integer) + m.writeInt(0) // 2nd argument value (port) + }) + + val payload = handler.handleBackendRequest(message) + val reply = new DataInputStream(new ByteArrayInputStream(payload)) + + assertEquals( + "status code must be successful.", 0, reply.readInt()) + assertEquals('j', reply.readByte()) + assertEquals(1, reply.readInt()) + val trackingId = new String(reply.readNBytes(1), "UTF-8") + assertEquals("1", trackingId) + val client = tracker.get(trackingId).get.asInstanceOf[Option[CallbackClient]].orNull + assertEquals(classOf[CallbackClient], client.getClass) + } + + private def givenMessage(func: DataOutputStream => Unit): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + func(new DataOutputStream(buffer)) + buffer.toByteArray + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala new file mode 100644 index 000000000..445486bbd 
--- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/DotnetBackendTest.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import org.junit.Assert._ +import org.junit.{After, Before, Test} + +import java.net.InetAddress + +@Test +class DotnetBackendTest { + private var backend: DotnetBackend = _ + + @Before + def before(): Unit = { + backend = new DotnetBackend + } + + @After + def after(): Unit = { + backend.close() + } + + @Test + def shouldNotResetCallbackClient(): Unit = { + // Specifying port = 0 to select port dynamically. + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + + assertTrue(backend.callbackClient.isDefined) + assertThrows(classOf[Exception], () => { + backend.setCallbackClient(InetAddress.getLoopbackAddress.toString, port = 0) + }) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala new file mode 100644 index 000000000..c6904403b --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/Extensions.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + + +package org.apache.spark.api.dotnet + +import java.io.DataInputStream + +private[dotnet] object Extensions { + implicit class DataInputStreamExt(stream: DataInputStream) { + def readNBytes(n: Int): Array[Byte] = { + val buf = new Array[Byte](n) + stream.readFully(buf) + buf + } + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala new file mode 100644 index 000000000..43ae79005 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/JVMObjectTrackerTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. 
+ */ + +package org.apache.spark.api.dotnet + +import org.junit.Test + +@Test +class JVMObjectTrackerTest { + + @Test + def shouldReleaseAllReferences(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + val thirdId = tracker.put(new Object) + + tracker.clear() + + assert(tracker.get(firstId).isEmpty) + assert(tracker.get(secondId).isEmpty) + assert(tracker.get(thirdId).isEmpty) + } + + @Test + def shouldResetCounter(): Unit = { + val tracker = new JVMObjectTracker + val firstId = tracker.put(new Object) + val secondId = tracker.put(new Object) + + tracker.clear() + + val thirdId = tracker.put(new Object) + + assert(firstId.equals("1")) + assert(secondId.equals("2")) + assert(thirdId.equals("1")) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala new file mode 100644 index 000000000..41401d680 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/api/dotnet/SerDeTest.scala @@ -0,0 +1,373 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.api.dotnet + +import org.apache.spark.api.dotnet.Extensions._ +import org.apache.spark.sql.Row +import org.junit.Assert._ +import org.junit.{Before, Test} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.sql.Date +import scala.collection.JavaConverters._ + +@Test +class SerDeTest { + private var serDe: SerDe = _ + private var tracker: JVMObjectTracker = _ + + @Before + def before(): Unit = { + tracker = new JVMObjectTracker + serDe = new SerDe(tracker) + } + + @Test + def shouldReadNull(): Unit = { + val input = givenInput(in => { + in.writeByte('n') + }) + + assertEquals(null, serDe.readObject(input)) + } + + @Test + def shouldThrowForUnsupportedTypes(): Unit = { + val input = givenInput(in => { + in.writeByte('_') + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadInteger(): Unit = { + val input = givenInput(in => { + in.writeByte('i') + in.writeInt(42) + }) + + assertEquals(42, serDe.readObject(input)) + } + + @Test + def shouldReadLong(): Unit = { + val input = givenInput(in => { + in.writeByte('g') + in.writeLong(42) + }) + + assertEquals(42L, serDe.readObject(input)) + } + + @Test + def shouldReadDouble(): Unit = { + val input = givenInput(in => { + in.writeByte('d') + in.writeDouble(42.42) + }) + + assertEquals(42.42, serDe.readObject(input)) + } + + @Test + def shouldReadBoolean(): Unit = { + val input = givenInput(in => { + in.writeByte('b') + in.writeBoolean(true) + }) + + assertEquals(true, serDe.readObject(input)) + } + + @Test + def shouldReadString(): Unit = { + val payload = "Spark Dotnet" + val input = givenInput(in => { + in.writeByte('c') + in.writeInt(payload.getBytes("UTF-8").length) + in.write(payload.getBytes("UTF-8")) + }) + + assertEquals(payload, serDe.readObject(input)) + } + + @Test + def shouldReadMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(3) // size + in.writeByte('i') // key type + in.writeInt(3) // number of keys + in.writeInt(11) // first key + in.writeInt(22) // second key + 
in.writeInt(33) // third key + in.writeInt(3) // number of values + in.writeByte('b') // first value type + in.writeBoolean(true) // first value + in.writeByte('d') // second value type + in.writeDouble(42.42) // second value + in.writeByte('n') // third type & value + }) + + assertEquals( + mapAsJavaMap(Map( + 11 -> true, + 22 -> 42.42, + 33 -> null)), + serDe.readObject(input)) + } + + @Test + def shouldReadEmptyMap(): Unit = { + val input = givenInput(in => { + in.writeByte('e') // map type descriptor + in.writeInt(0) // size + }) + + assertEquals(mapAsJavaMap(Map()), serDe.readObject(input)) + } + + @Test + def shouldReadBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(3) // length + in.write(Array[Byte](1, 2, 3)) // payload + }) + + assertArrayEquals(Array[Byte](1, 2, 3), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyBytesArray(): Unit = { + val input = givenInput(in => { + in.writeByte('r') // byte array type descriptor + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Byte](), serDe.readObject(input).asInstanceOf[Array[Byte]]) + } + + @Test + def shouldReadEmptyList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('i') // element type + in.writeInt(0) // length + }) + + assertArrayEquals(Array[Int](), serDe.readObject(input).asInstanceOf[Array[Int]]) + } + + @Test + def shouldReadList(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('b') // element type + in.writeInt(3) // length + in.writeBoolean(true) + in.writeBoolean(false) + in.writeBoolean(true) + }) + + assertArrayEquals(Array(true, false, true), serDe.readObject(input).asInstanceOf[Array[Boolean]]) + } + + @Test + def shouldThrowWhenReadingListWithUnsupportedType(): Unit = { + val input = givenInput(in => { + in.writeByte('l') // type descriptor + in.writeByte('_') // unsupported element type + }) + + assertThrows(classOf[IllegalArgumentException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadDate(): Unit = { + val input = givenInput(in => { + val date = "2020-12-31" + in.writeByte('D') // type descriptor + in.writeInt(date.getBytes("UTF-8").length) // date string size + in.write(date.getBytes("UTF-8")) + }) + + assertEquals(Date.valueOf("2020-12-31"), serDe.readObject(input)) + } + + @Test + def shouldReadObject(): Unit = { + val trackingObject = new Object + tracker.put(trackingObject) + val input = givenInput(in => { + val objectIndex = "1" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertSame(trackingObject, serDe.readObject(input)) + } + + @Test + def shouldThrowWhenReadingNonTrackingObject(): Unit = { + val input = givenInput(in => { + val objectIndex = "42" + in.writeByte('j') // type descriptor + in.writeInt(objectIndex.getBytes("UTF-8").length) // size + in.write(objectIndex.getBytes("UTF-8")) + }) + + assertThrows(classOf[NoSuchElementException], () => { + serDe.readObject(input) + }) + } + + @Test + def shouldReadSparkRows(): Unit = { + val input = givenInput(in => { + in.writeByte('R') // type descriptor + in.writeInt(2) // number of rows + in.writeInt(1) // number of elements in 1st row + in.writeByte('i') // type of 1st element in 1st row + in.writeInt(11) + in.writeInt(3) // number of elements in 2st row + in.writeByte('b') // type of 1st element in 2nd row + 
in.writeBoolean(true) + in.writeByte('d') // type of 2nd element in 2nd row + in.writeDouble(42.24) + in.writeByte('g') // type of 3nd element in 2nd row + in.writeLong(99) + }) + + assertEquals( + seqAsJavaList(Seq( + Row.fromSeq(Seq(11)), + Row.fromSeq(Seq(true, 42.24, 99)))), + serDe.readObject(input)) + } + + @Test + def shouldReadArrayOfObjects(): Unit = { + val input = givenInput(in => { + in.writeByte('O') // type descriptor + in.writeInt(2) // number of elements + in.writeByte('i') // type of 1st element + in.writeInt(42) + in.writeByte('b') // type of 2nd element + in.writeBoolean(true) + }) + + assertEquals(Seq(42, true), serDe.readObject(input).asInstanceOf[Seq[Any]]) + } + + @Test + def shouldWriteNull(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, null) + serDe.writeObject(out, Unit) + }) + + assertEquals(in.readByte(), 'n') + assertEquals(in.readByte(), 'n') + assertEndOfStream(in) + } + + @Test + def shouldWriteString(): Unit = { + val sparkDotnet = "Spark Dotnet" + val in = whenOutput(out => { + serDe.writeObject(out, sparkDotnet) + }) + + assertEquals(in.readByte(), 'c') // object type + assertEquals(in.readInt(), sparkDotnet.length) // length + assertArrayEquals(in.readNBytes(sparkDotnet.length), sparkDotnet.getBytes("UTF-8")) + assertEndOfStream(in) + } + + @Test + def shouldWritePrimitiveTypes(): Unit = { + val in = whenOutput(out => { + serDe.writeObject(out, 42.24f.asInstanceOf[Object]) + serDe.writeObject(out, 42L.asInstanceOf[Object]) + serDe.writeObject(out, 42.asInstanceOf[Object]) + serDe.writeObject(out, true.asInstanceOf[Object]) + }) + + assertEquals(in.readByte(), 'd') + assertEquals(in.readDouble(), 42.24F, 0.000001) + assertEquals(in.readByte(), 'g') + assertEquals(in.readLong(), 42L) + assertEquals(in.readByte(), 'i') + assertEquals(in.readInt(), 42) + assertEquals(in.readByte(), 'b') + assertEquals(in.readBoolean(), true) + assertEndOfStream(in) + } + + @Test + def shouldWriteDate(): Unit = { + val date = "2020-12-31" + val in = whenOutput(out => { + serDe.writeObject(out, Date.valueOf(date)) + }) + + assertEquals(in.readByte(), 'D') // type + assertEquals(in.readInt(), 10) // size + assertArrayEquals(in.readNBytes(10), date.getBytes("UTF-8")) // content + } + + @Test + def shouldWriteCustomObjects(): Unit = { + val customObject = new Object + val in = whenOutput(out => { + serDe.writeObject(out, customObject) + }) + + assertEquals(in.readByte(), 'j') + assertEquals(in.readInt(), 1) + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) + assertSame(tracker.get("1").get, customObject) + } + + @Test + def shouldWriteArrayOfCustomObjects(): Unit = { + val payload = Array(new Object, new Object) + val in = whenOutput(out => { + serDe.writeObject(out, payload) + }) + + assertEquals(in.readByte(), 'l') // array type + assertEquals(in.readByte(), 'j') // type of element in array + assertEquals(in.readInt(), 2) // array length + assertEquals(in.readInt(), 1) // size of 1st element's identifiers + assertArrayEquals(in.readNBytes(1), "1".getBytes("UTF-8")) // identifier of 1st element + assertEquals(in.readInt(), 1) // size of 2nd element's identifier + assertArrayEquals(in.readNBytes(1), "2".getBytes("UTF-8")) // identifier of 2nd element + assertSame(tracker.get("1").get, payload(0)) + assertSame(tracker.get("2").get, payload(1)) + } + + private def givenInput(func: DataOutputStream => Unit): DataInputStream = { + val buffer = new ByteArrayOutputStream() + val out = new DataOutputStream(buffer) + func(out) + new DataInputStream(new 
ByteArrayInputStream(buffer.toByteArray)) + } + + private def whenOutput = givenInput _ + + private def assertEndOfStream (in: DataInputStream): Unit = { + assertEquals(-1, in.read()) + } +} diff --git a/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala new file mode 100644 index 000000000..9d4379436 --- /dev/null +++ b/src/scala/microsoft-spark-3-4/src/test/scala/org/apache/spark/util/dotnet/UtilsTest.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the .NET Foundation under one or more agreements. + * The .NET Foundation licenses this file to you under the MIT license. + * See the LICENSE file in the project root for more information. + */ + +package org.apache.spark.util.dotnet + +import org.apache.spark.SparkConf +import org.apache.spark.internal.config.dotnet.Dotnet.DOTNET_IGNORE_SPARK_PATCH_VERSION_CHECK +import org.junit.Assert.{assertEquals, assertThrows} +import org.junit.Test + +@Test +class UtilsTest { + + @Test + def shouldIgnorePatchVersion(): Unit = { + val sparkVersion = "3.4.0" + val sparkMajorMinorVersionPrefix = "3.4" + val supportedSparkVersions = Set[String]("3.4.0") + + Utils.validateSparkVersions( + true, + sparkVersion, + Utils.normalizeSparkVersion(sparkVersion), + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + } + + @Test + def shouldThrowForUnsupportedVersion(): Unit = { + val sparkVersion = "3.4.0" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.4" + val supportedSparkVersions = Set[String]("3.4.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported versions: '${supportedSparkVersions.toSeq.sorted.mkString(", ")}'.", + exception.getMessage) + } + + @Test + def shouldThrowForUnsupportedMajorMinorVersion(): Unit = { + val sparkVersion = "2.4.4" + val normalizedSparkVersion = Utils.normalizeSparkVersion(sparkVersion) + val sparkMajorMinorVersionPrefix = "3.4" + val supportedSparkVersions = Set[String]("3.4.0") + + val exception = assertThrows( + classOf[IllegalArgumentException], + () => { + Utils.validateSparkVersions( + false, + sparkVersion, + normalizedSparkVersion, + sparkMajorMinorVersionPrefix, + supportedSparkVersions) + }) + + assertEquals( + s"Unsupported spark version used: '$sparkVersion'. " + + s"Normalized spark version used: '$normalizedSparkVersion'. " + + s"Supported spark major.minor version: '$sparkMajorMinorVersionPrefix'.", + exception.getMessage) + } +} diff --git a/src/scala/pom.xml b/src/scala/pom.xml index 890af92b8..53fe6d4f2 100644 --- a/src/scala/pom.xml +++ b/src/scala/pom.xml @@ -15,6 +15,8 @@ microsoft-spark-3-0 microsoft-spark-3-1 microsoft-spark-3-2 + microsoft-spark-3-3 + microsoft-spark-3-4